In [2]:
import torch
import numpy as np
import torch.nn as nn

class onering_conv_layer(nn.Module):
    """The convolutional layer on icosahedron discretized sphere using a 1-ring filter.

    For every vertex i, the rows ``neigh_orders[7i:7i+7]`` of the input are gathered,
    concatenated into one vector of length ``7 * in_feats`` and mapped by a single
    linear layer.

    Parameters:
            in_feats (int) - - input features/channels
            out_feats (int) - - output features/channels
            neigh_orders (array of int) - - flat index array of length 7 * N
            neigh_indices, neigh_weights - - accepted for signature compatibility; unused

    Input:
        N x in_feats tensor
    Return:
        N x out_feats tensor
    """
    def __init__(self, in_feats, out_feats, neigh_orders, neigh_indices=None, neigh_weights=None):
        super(onering_conv_layer, self).__init__()

        self.in_feats = in_feats
        self.out_feats = out_feats
        self.neigh_orders = neigh_orders
        # A single linear map over the concatenated 7-entry neighbourhood.
        self.weight = nn.Linear(7 * in_feats, out_feats)

    def forward(self, x):
        # x[neigh_orders] is (7N, in_feats); consecutive groups of 7 rows belong to one vertex.
        mat = x[self.neigh_orders].view(len(x), 7 * self.in_feats)
        return self.weight(mat)
    
    

class pool_layer(nn.Module):
    """
    The pooling layer on icosahedron discretized sphere using 1-ring filter

    For each of the first (N+6)/4 vertices i, the 7 rows listed in
    ``neigh_orders[7i:7i+7]`` are gathered and reduced with a mean or a max.

    Parameters:
        neigh_orders (array of int): flat 1-ring table with 7 entries per vertex
        pooling_type (str): 'mean' or 'max'; anything else raises ValueError
        legacy_layout (bool): True (default) keeps the reshape
            ``view(num_nodes, feat_num, 7)`` of the gathered (7*num_nodes) x D rows.
            That reshape does not separate neighbours from channels. Each output value is
            reduced over 7 consecutive entries of the node's row-major neighbour block, so
            channels of different neighbours are mixed whenever D > 1. False reduces each
            channel over the 7 neighbours.
            NOTE(review): the default is kept because weights trained with this layer
            presumably depend on the legacy behaviour -- confirm before changing it.

    Input:
        N x D tensor
    Return:
        ((N+6)/4) x D tensor for 'mean'; a (values, indices) pair of that shape for 'max'
    """

    def __init__(self, neigh_orders, pooling_type='mean', legacy_layout=True):
        super(pool_layer, self).__init__()
        if pooling_type not in ('mean', 'max'):
            raise ValueError("pooling_type must be 'mean' or 'max', got %r" % (pooling_type,))

        self.neigh_orders = neigh_orders
        self.pooling_type = pooling_type
        self.legacy_layout = legacy_layout

    def forward(self, x):
        num_nodes = (x.size(0) + 6) // 4
        feat_num = x.size(1)
        # (7 * num_nodes, feat_num): rows 7i..7i+6 are the neighbours of vertex i
        gathered = x[self.neigh_orders[0:num_nodes * 7]]
        if self.legacy_layout:
            gathered = gathered.view(num_nodes, feat_num, 7)
            reduce_dim = 2
        else:
            gathered = gathered.view(num_nodes, 7, feat_num)
            reduce_dim = 1

        if self.pooling_type == 'max':
            values, indices = torch.max(gathered, reduce_dim)
            assert values.size() == torch.Size([num_nodes, feat_num])
            return values, indices

        pooled = torch.mean(gathered, reduce_dim)
        assert pooled.size() == torch.Size([num_nodes, feat_num])
        return pooled

In [3]:
import scipy.io as sio
import numpy as np
def Get_neighs_order(rotated=0):
    """Load the 1-ring neighbourhood tables for every icosahedral level.

    Returns an 8-tuple ordered 163842, 40962, 10242, 2562, 642, 162, 42, 12 vertices;
    each element is ``get_neighs_order(n_vertex, rotated)`` for that level.
    """
    levels = (163842, 40962, 10242, 2562, 642, 162, 42, 12)
    return tuple(get_neighs_order(n_vertex, rotated) for n_vertex in levels)
  
def get_neighs_order(n_vertex, rotated=0):
    """Load one level's 1-ring table as a flat int64 array with 7 entries per vertex.

    For vertex i, entries 7i..7i+5 are the six values of row i of 'adj_mat_order' minus 1
    (presumably converting MATLAB 1-based indices to 0-based), and entry 7i+6 is i itself.
    The file is read from ``abspath/neigh_indices/``.
    """
    mat_path = abspath + '/neigh_indices/adj_mat_order_' + str(n_vertex) + '_rotated_' + str(rotated) + '.mat'
    adjacency = sio.loadmat(mat_path)['adj_mat_order']
    n_nodes = len(adjacency)

    table = np.zeros((n_nodes, 7))
    table[:, 0:6] = adjacency - 1
    table[:, 6] = np.arange(n_nodes)  # the centre vertex occupies the last slot
    return table.ravel().astype(np.int64)

def Get_upconv_index(rotated=0):
    """Build the upconv (top, down) index pairs for every level that has a coarser parent.

    Returns a flat 12-tuple: (top, down) for 163842, 40962, 10242, 2562, 642 and 162
    vertices, in that order. Each pair is what ``get_upconv_index`` returns for that
    level's adjacency file under ``abspath/neigh_indices/``.
    """
    indices = ()
    for n_vertex in (163842, 40962, 10242, 2562, 642, 162):
        mat_path = abspath + '/neigh_indices/adj_mat_order_' + str(n_vertex) + '_rotated_' + str(rotated) + '.mat'
        top, down = get_upconv_index(mat_path)
        indices += (top, down)

    #TODO: return tuples of each level
    return indices


def get_upconv_index(order_path):
    """Compute gather indices used by ``upconv_layer`` for one level.

    The adjacency table in ``order_path`` ('adj_mat_order', values minus 1) has one row per
    fine vertex; the first (n+6)/4 rows are the coarse vertices.

    Returns:
        top_index (int64 array, one entry per coarse vertex): 7 * i + 6 for vertex i.
        down_index (int64 array, two entries per remaining vertex): for each such vertex, and
            for each of the two entries in its row that are coarse vertices (in row order),
            ``parent * 7 + k`` where k is the first position of the vertex in the parent's row.
    """
    adjacency = sio.loadmat(order_path)['adj_mat_order'] - 1
    n_fine = len(adjacency)
    n_coarse = int((n_fine + 6) / 4)

    top_index = np.arange(n_coarse, dtype=np.int64) * 7 + 6

    down_index = np.full((n_fine - n_coarse) * 2, -1, dtype=np.int64)
    for offset, child in enumerate(range(n_coarse, n_fine)):
        row = adjacency[child]
        parents = row[row < n_coarse]
        assert len(parents) == 2
        for j, parent in enumerate(parents):
            slot = np.where(adjacency[parent] == child)[0][0]
            down_index[offset * 2 + j] = parent * 7 + slot

    return top_index, down_index


def get_upsample_order(n_vertex, neigh_orders=None):
    """Return, for every vertex added at this level, its two neighbours from the coarser level.

    Parameters:
        n_vertex (int): number of vertices of the level (e.g. 40962).
        neigh_orders (array of int, optional): flat 7-per-vertex table as returned by
            ``get_neighs_order``; loaded with ``get_neighs_order(n_vertex, 0)`` when omitted.

    Returns:
        Flat array of length 2 * (n_vertex - n_last), n_last = (n_vertex + 6) // 4. For each
        row n_last..n_vertex-1 of the table, in row order, it holds the two entries that are
        smaller than n_last (in column order).
    """
    n_last = (n_vertex + 6) // 4
    if neigh_orders is None:
        # get_neighs_order takes the vertex count and builds the .mat path itself.
        neigh_orders = get_neighs_order(n_vertex, 0)
    neigh_orders = np.asarray(neigh_orders).reshape(n_vertex, 7)
    neigh_orders = neigh_orders[n_last:, :]
    row, col = (neigh_orders < n_last).nonzero()
    assert len(row) == (n_vertex - n_last)*2, "len(row) == (n_vertex - n_last)*2, error!"

    # Every new vertex must appear exactly twice, in order, in `row`.
    u, indices, counts = np.unique(row, return_index=True, return_counts=True)
    assert len(u) == n_vertex - n_last, "len(u) == n_vertex - n_last, error"
    assert u.min() == 0 and u.max() == n_vertex-n_last-1, "u.min() == 0 and u.max() == n_vertex-n_last-1, error"
    assert (indices == np.arange(n_vertex - n_last) * 2).sum() == n_vertex - n_last, "(indices == np.asarray(list(range(n_vertex - n_last))) * 2).sum() == n_vertex - n_last, error"
    assert (counts == 2).sum() == n_vertex - n_last, "(counts == 2).sum() == n_vertex - n_last, error"

    return neigh_orders[row, col]


import os

# Directory that contains neigh_indices/*.mat. Set the SPHERICALUNET_UTILS environment
# variable to point elsewhere instead of editing this absolute path.
abspath = os.environ.get('SPHERICALUNET_UTILS',
                         'C:/Users/DELL/Desktop/kaiti/SphericalUNetPackage/sphericalunet/utils')
neigh_orders = Get_neighs_order()
neigh_orders = neigh_orders[1:]  # drop the 163842-vertex level; index 0 is now 40962
# a, b are the 163842-level (top, down) indices, not used by the 40k model.
a, b, upconv_top_index_40962, upconv_down_index_40962, upconv_top_index_10242, upconv_down_index_10242,  upconv_top_index_2562, upconv_down_index_2562,  upconv_top_index_642, upconv_down_index_642, upconv_top_index_162, upconv_down_index_162 = Get_upconv_index()

In [4]:
import scipy.io as sio 
import torch.nn as nn
class down_block(nn.Module):
    """Downsampling block of the spherical U-Net.

    Layout: [mean pool =>] (conv => BN => LeakyReLU) * 2; the pooling stage is left out
    when ``first`` is True.

    Parameters:
        conv_layer: constructor called as conv_layer(in_channels, out_channels, neigh_orders)
        in_ch (int): input features/channels
        out_ch (int): output features/channels
        neigh_orders: neighbourhood orders given to both conv layers
        pool_neigh_orders: neighbourhood orders given to pool_layer (unused when first=True)
        first (bool): skip the pooling layer
    """
    def __init__(self, conv_layer, in_ch, out_ch, neigh_orders, pool_neigh_orders, first=False):
        super(down_block, self).__init__()

        # Sequential indices become state_dict keys, so this layer order must stay fixed
        # for saved checkpoints to load.
        layers = [] if first else [pool_layer(pool_neigh_orders, 'mean')]
        for conv_in in (in_ch, out_ch):
            layers += [
                conv_layer(conv_in, out_ch, neigh_orders),
                nn.BatchNorm1d(out_ch, momentum=0.15, affine=True, track_running_stats=False),
                nn.LeakyReLU(0.2, inplace=True),
            ]
        self.block = nn.Sequential(*layers)

    def forward(self, x):
        return self.block(x)

class upconv_layer(nn.Module):
    """
    The transposed convolution layer on icosahedron discretized sphere using 1-ring filter

    A linear layer expands each of the N input vertices into 7 vectors of size
    ``out_feats``, giving a (7N) x out_feats table. The first N output rows are the table
    rows picked by ``upconv_top_index``. Each further output row averages two rows picked
    by consecutive entries of ``upconv_down_index``.

    Parameters:
        in_feats (int): input features/channels
        out_feats (int): output features/channels
        upconv_top_index (array of int): N row indices into the expanded table
        upconv_down_index (array of int): 2 row indices per additional output vertex
        legacy_layout (bool): True (default) keeps the reshape ``view(-1, out_feats, 2)``.
            That reshape averages pairs of interleaved values rather than the two selected
            rows channel by channel. False averages the two rows per channel.
            NOTE(review): the default is kept because weights trained with this layer
            presumably depend on the legacy behaviour -- confirm before changing it.

    Input:
        N x in_feats, tensor
    Return:
        ((Nx4)-6) x out_feats, tensor
    """

    def __init__(self, in_feats, out_feats, upconv_top_index, upconv_down_index, legacy_layout=True):
        super(upconv_layer, self).__init__()

        self.in_feats = in_feats
        self.out_feats = out_feats
        self.upconv_top_index = upconv_top_index
        self.upconv_down_index = upconv_down_index
        self.legacy_layout = legacy_layout
        self.weight = nn.Linear(in_feats, 7 * out_feats)

    def forward(self, x):
        raw_nodes = x.size(0)
        new_nodes = raw_nodes * 4 - 6
        x = self.weight(x).view(raw_nodes * 7, self.out_feats)

        x1 = x[self.upconv_top_index]
        assert x1.size() == torch.Size([raw_nodes, self.out_feats])

        gathered = x[self.upconv_down_index]  # rows 2m, 2m+1 belong to output vertex m
        if self.legacy_layout:
            x2 = torch.mean(gathered.view(-1, self.out_feats, 2), 2)
        else:
            x2 = torch.mean(gathered.view(-1, 2, self.out_feats), 1)

        x = torch.cat((x1, x2), 0)
        assert x.size() == torch.Size([new_nodes, self.out_feats])
        return x
    
    
class up_block(nn.Module):
    """Upsampling block of the spherical U-Net: upconv => (conv => BN => LeakyReLU) * 2.

    Parameters:
            conv_layer - - constructor called as conv_layer(in_channels, out_channels, neigh_orders)
            in_ch (int) - - channels of the coarse input, and of the concatenation fed to the first conv
            out_ch (int) - - output features/channels
            neigh_orders (tensor, int) - - conv layers' neighbourhood orders
            upconv_top_index, upconv_down_index - - gather indices handed to upconv_layer
    """
    def __init__(self, conv_layer, in_ch, out_ch, neigh_orders, upconv_top_index, upconv_down_index):
        super(up_block, self).__init__()

        self.up = upconv_layer(in_ch, out_ch, upconv_top_index, upconv_down_index)

        # Sequential indices are state_dict keys; keep conv/BN/ReLU order fixed.
        stages = []
        for stage_in in (in_ch, out_ch):
            stages.extend([
                conv_layer(stage_in, out_ch, neigh_orders),
                nn.BatchNorm1d(out_ch, momentum=0.15, affine=True, track_running_stats=False),
                nn.LeakyReLU(0.2, inplace=True),
            ])
        self.double_conv = nn.Sequential(*stages)

    def forward(self, x1, x2):
        # Upsample the coarse features, then append the skip features along the channel axis.
        upsampled = self.up(x1)
        return self.double_conv(torch.cat((upsampled, x2), 1))
    
class Unet_40k(nn.Module):
    """Spherical U-Net on the 40962-vertex icosahedral sphere.

    Five down blocks (40962 -> 10242 -> 2562 -> 642 -> 162 vertices), four up blocks with
    skip connections, and a per-vertex linear output layer.
    """
    def __init__(self, in_ch, out_ch):
        """ Initialize the Spherical UNet.

        Parameters:
            in_ch (int) - - input features/channels
            out_ch (int) - - output features/channels
        """
        super(Unet_40k, self).__init__()

        # Load only the levels this network uses. Get_neighs_order()/Get_upconv_index() would
        # also read the 163842-vertex tables, whose upconv index alone loops in Python over
        # ~123k vertices, and none of that is used here.
        neigh_orders = [get_neighs_order(n_vertex) for n_vertex in (40962, 10242, 2562, 642, 162)]
        upconv = {}
        for n_vertex in (40962, 10242, 2562, 642):
            upconv[n_vertex] = get_upconv_index(
                abspath + '/neigh_indices/adj_mat_order_' + str(n_vertex) + '_rotated_0.mat')

        chs = [in_ch, 32, 64, 128, 256, 512]

        conv_layer = onering_conv_layer

        # Construction order is unchanged so parameter initialisation order is too.
        self.down1 = down_block(conv_layer, chs[0], chs[1], neigh_orders[0], None, True)
        self.down2 = down_block(conv_layer, chs[1], chs[2], neigh_orders[1], neigh_orders[0])
        self.down3 = down_block(conv_layer, chs[2], chs[3], neigh_orders[2], neigh_orders[1])
        self.down4 = down_block(conv_layer, chs[3], chs[4], neigh_orders[3], neigh_orders[2])
        self.down5 = down_block(conv_layer, chs[4], chs[5], neigh_orders[4], neigh_orders[3])

        self.up1 = up_block(conv_layer, chs[5], chs[4], neigh_orders[3], *upconv[642])
        self.up2 = up_block(conv_layer, chs[4], chs[3], neigh_orders[2], *upconv[2562])
        self.up3 = up_block(conv_layer, chs[3], chs[2], neigh_orders[1], *upconv[10242])
        self.up4 = up_block(conv_layer, chs[2], chs[1], neigh_orders[0], *upconv[40962])

        self.outc = nn.Sequential(
                nn.Linear(chs[1], out_ch)
                )

    def forward(self, x):
        x2 = self.down1(x)
        x3 = self.down2(x2)
        x4 = self.down3(x3)
        x5 = self.down4(x4)
        x6 = self.down5(x5)

        x = self.up1(x6, x5)
        x = self.up2(x, x4)
        x = self.up3(x, x3)
        x = self.up4(x, x2) # 40962 x 32

        x = self.outc(x) # 40962 x out_ch
        return x

In [33]:
import torch
import argparse
import torchvision
import numpy as np
import glob
import os

# from model import Unet_40k, Unet_160k
from sphericalunet.utils.vtk import read_vtk, write_vtk, resample_label
# from sphericalunet.utils.utils import get_par_36_to_fs_vec
from sphericalunet.utils.interp_numpy import resampleSphereSurf



def get_par_fs_to_36(file=r'C:\Users\DELL\Desktop\kaiti\Spherical_U-Net\neigh_indices\template.vtk'):
    """ Preprocessing for parcellation label: map each 'par_fs' label to a class index.

    Parameters:
        file (str): surface read with ``read_vtk``; must contain a 'par_fs' array.

    Returns:
        dict mapping each distinct label to its rank among the labels sorted ascending
        (the smallest label maps to 0).
    """
    par_fs = read_vtk(file)['par_fs']
    # np.unique already returns the distinct labels in ascending order.
    return {label: i for i, label in enumerate(np.unique(par_fs))}


def get_par_36_to_fs_vec(file=r'C:\Users\DELL\Desktop\kaiti\Spherical_U-Net\neigh_indices\template.vtk'):
    """ Preprocessing for parcellation label: map each class index to its 'par_fs_vec' value.

    Class i is the i-th smallest distinct 'par_fs' label (the same numbering as
    ``get_par_fs_to_36``). Its value is the 'par_fs_vec' entry of the first vertex that
    carries that label.

    Parameters:
        file (str): surface read with ``read_vtk``; must contain 'par_fs' and 'par_fs_vec'.

    Returns:
        dict {class index (int): par_fs_vec entry}
    """
    data = read_vtk(file)
    par_fs = data['par_fs']
    par_fs_vec = data['par_fs_vec']
    # return_index yields the first vertex of each sorted label, so the file is read once
    # and scanned once instead of once per class.
    _, first_vertex = np.unique(par_fs, return_index=True)
    return {i: par_fs_vec[v] for i, v in enumerate(first_vertex)}


# def get_par_36_to_fs_vec():
#     """ Preprocessing for parcellatiion label """
#     file = '/media/fenqiang/DATA/unc/Data/NITRC/data/left/train/MNBCP107842_809.lh.SphereSurf.Orig.Resample.vtk'
#     data = read_vtk(file)
#     par_fs = data['par_fs']
#     par_fs_vec = data['par_fs_vec']
#     par_fs_to_36 = get_par_fs_to_36()
#     par_36_to_fs = dict(zip(par_fs_to_36.values(), par_fs_to_36.keys()))
#     par_36_to_fs_vec = {}
#     for i in range(len(par_fs_to_36)):
#         par_36_to_fs_vec[i] = par_fs_vec[np.where(par_fs == par_36_to_fs[i])[0][0]]
#     return par_36_to_fs_vec

class BrainSphere(torch.utils.data.Dataset):
    """Dataset over the ``*.vtk`` files directly inside ``root1``, in sorted path order.

    Each item is ``(read_vtk(path), path)``.
    """

    def __init__(self, root1):
        pattern = os.path.join(root1, '*.vtk')
        self.files = sorted(glob.glob(pattern))

    def __getitem__(self, index):
        path = self.files[index]
        return read_vtk(path), path

    def __len__(self):
        return len(self.files)


def inference(curv, sulc, model):
    """Predict one class index per vertex from the `curv` and `sulc` features.

    Parameters:
        curv, sulc: tensors concatenated along dim 1 into an N x 2 input
            (assumed N x 1 each, as built in the main cell).
        model: module mapping the N x 2 input to N x C scores.

    Returns:
        numpy int64 array of length N holding the arg-max class of each vertex.
    """
    feats = torch.cat((curv, sulc), 1)
    # Fixed per-column scaling: curv / 1.2, sulc / 13.7. A new tensor is produced, so the
    # caller's inputs are left untouched.
    feat_max = torch.tensor([1.2, 13.7], dtype=feats.dtype)
    feats = feats / feat_max

    # Follow the model's device and dtype rather than a global `device` variable.
    param = next(model.parameters(), None)
    if param is not None:
        feats = feats.to(device=param.device, dtype=param.dtype)

    with torch.no_grad():
        prediction = model(feats)
    return prediction.argmax(dim=1).cpu().numpy()





if __name__ == "__main__":
    # Single-surface prediction: left hemisphere, level 7 (40962 vertices), curv + sulc input.
    in_file = 'C:/Users/DELL/Desktop/kaiti/Spherical_U-Net/examples/left_hemisphere/40962/test1.lh.40k.vtk'
    out_file = 'C:/Users/DELL/Desktop/kaiti/Spherical_U-Net/examples/left_hemisphere/40962/test_output.vtk'
    hemi = 'left'
    level = 7
    n_vertices = 40962
    model_path = r'C:\Users\DELL\Desktop\kaiti\Spherical_U-Net\trained_models\left_hemi_40k_curv_sulc.pkl'
    # Use the GPU when available, otherwise the CPU; map_location lets a checkpoint saved
    # from GPU tensors load on either.
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

    model = Unet_40k(2, 36)
    model.to(device)
    model.load_state_dict(torch.load(model_path, map_location=device))
    model.eval()

    # Class index <-> 'par_fs' label, and class index -> 'par_fs_vec' value, from the template.
    par_fs_to_36 = get_par_fs_to_36()
    par_36_to_fs = dict(zip(par_fs_to_36.values(), par_fs_to_36.keys()))
    par_36_to_fs_vec = get_par_36_to_fs_vec()
    print(par_36_to_fs_vec)

    # TODO: surfaces at another resolution could be resampled onto the 40962-vertex template
    # sphere with resampleSphereSurf before inference; here they are rejected.
    if in_file is not None:
        data = read_vtk(in_file)
        if len(data['curv']) != n_vertices:
            raise NotImplementedError('Input surfaces level is not consistent with the level '
                                      + str(level) + ' that the model was trained on.')
        curv = torch.from_numpy(data['curv'][0:n_vertices]).unsqueeze(1) # use curv data with 40k vertices
        sulc = torch.from_numpy(data['sulc'][0:n_vertices]).unsqueeze(1) # use sulc data with 40k vertices
        pred = inference(curv, sulc, model)
        print(pred)
        data['par_fs'] = np.array([par_36_to_fs[i] for i in pred])
        data['par_fs_vec'] = np.array([par_36_to_fs_vec[i] for i in pred])
        print(data['par_fs'])
        print(data['par_fs_vec'])
        write_vtk(data, out_file)

{0: array([100,  25,   0]), 1: array([220,  20,  10]), 2: array([220,  20,  20]), 3: array([220,  60,  20]), 4: array([ 80, 160,  20]), 5: array([25,  5, 25]), 6: array([255, 192,  32]), 7: array([ 25, 100,  40]), 8: array([120,  70,  50]), 9: array([35, 75, 50]), 10: array([ 20, 100,  50]), 11: array([160, 100,  50]), 12: array([120, 100,  60]), 13: array([ 20, 220,  60]), 14: array([ 60, 220,  60]), 15: array([200,  35,  75]), 16: array([100,   0, 100]), 17: array([220,  20, 100]), 18: array([180,  40, 120]), 19: array([ 75,  50, 125]), 20: array([ 80,  20, 140]), 21: array([140,  20, 140]), 22: array([ 20,  30, 140]), 23: array([225, 140, 140]), 24: array([ 20, 180, 140]), 25: array([220, 180, 140]), 26: array([180, 220, 140]), 27: array([125, 100, 160]), 28: array([ 20, 220, 160]), 29: array([ 70,  20, 170]), 30: array([160, 140, 180]), 31: array([150, 150, 200]), 32: array([ 60,  20, 220]), 33: array([220,  60, 220]), 34: array([220, 180, 220]), 35: array([140, 220, 220])}
[32 24 