In [2]:
import torch
import numpy as np
import torch.nn as nn

class onering_conv_layer(nn.Module):
    """The convolutional layer on icosahedron discretized sphere using 
    1-ring filter
    
    Parameters:
            in_feats (int) - - input features/channels
            out_feats (int) - - output features/channels
            
    Input: 
        N x in_feats tensor
    Return:
        N x out_feats tensor
    """  
    def __init__(self, in_feats, out_feats, neigh_orders, neigh_indices=None, neigh_weights=None):
        super(onering_conv_layer, self).__init__()

        self.in_feats = in_feats
        self.out_feats = out_feats
        self.neigh_orders = neigh_orders
        self.pool = pool_layer(self.neigh_orders, pooling_type='mean')
        self.weight = nn.Linear(7 * in_feats, out_feats)
        # self.norm = nn.BatchNorm1d(out_feats, momentum=0.15, affine=True, track_running_stats=False)
        # self.relu = nn.LeakyReLU(0.2, inplace=True)
        # self.dropout = nn.Dropout(0.7)
        
    def forward(self, x):
       
        mat = x[self.neigh_orders].view(len(x), 7*self.in_feats)
                
        out_features = self.weight(mat)
        # out_features = self.norm(out_features)
        # out_features = self.relu(out_features)
        # out_features = self.pool(out_features)
        # out_features = self.dropout(out_features)
        
        return out_features
    
    

class pool_layer(nn.Module):
    """
    The pooling layer on icosahedron discretized sphere using 1-ring filter
    
    Input: 
        N x D tensor
    Return:
        ((N+6)/4) x D tensor
    
    """  

    def __init__(self, neigh_orders, pooling_type='mean'):
        super(pool_layer, self).__init__()

        self.neigh_orders = neigh_orders
        self.pooling_type = pooling_type
        
    def forward(self, x):
       
        num_nodes = int((x.size()[0]+6)/4)
        feat_num = x.size()[1]
        x = x[self.neigh_orders[0:num_nodes*7]].view(num_nodes, feat_num, 7)
        if self.pooling_type == "mean":
            x = torch.mean(x, 2)
        if self.pooling_type == "max":
            x = torch.max(x, 2)
            assert(x[0].size() == torch.Size([num_nodes, feat_num]))
            return x[0], x[1]
        
        assert(x.size() == torch.Size([num_nodes, feat_num]))
                
        return x
    


In [6]:
import scipy.io as sio
import numpy as np
def Get_neighs_order(rotated=0):
    neigh_orders_163842 = get_neighs_order(163842, rotated)
    neigh_orders_40962 = get_neighs_order(40962, rotated)
    neigh_orders_10242 = get_neighs_order(10242, rotated)
    neigh_orders_2562 = get_neighs_order(2562, rotated)
    neigh_orders_642 = get_neighs_order(642, rotated)
    neigh_orders_162 = get_neighs_order(162, rotated)
    neigh_orders_42 = get_neighs_order(42, rotated)
    neigh_orders_12 = get_neighs_order(12, rotated)
    
    return neigh_orders_163842, neigh_orders_40962, neigh_orders_10242,\
        neigh_orders_2562, neigh_orders_642, neigh_orders_162, neigh_orders_42, neigh_orders_12
  
def get_neighs_order(n_vertex, rotated=0):
    adj_mat_order = sio.loadmat(abspath +'/neigh_indices/adj_mat_order_'+ \
                                str(n_vertex) +'_rotated_' + str(rotated) + '.mat')
    adj_mat_order = adj_mat_order['adj_mat_order']
    neigh_orders = np.zeros((len(adj_mat_order), 7))
    neigh_orders[:,0:6] = adj_mat_order-1
    neigh_orders[:,6] = np.arange(len(adj_mat_order))
    neigh_orders = np.ravel(neigh_orders).astype(np.int64)
    
    return neigh_orders

def Get_upconv_index(rotated=0):
    
    upconv_top_index_163842, upconv_down_index_163842 = get_upconv_index(abspath+'/neigh_indices/adj_mat_order_163842_rotated_' + str(rotated) + '.mat')
    upconv_top_index_40962, upconv_down_index_40962 = get_upconv_index(abspath+'/neigh_indices/adj_mat_order_40962_rotated_' + str(rotated) + '.mat')
    upconv_top_index_10242, upconv_down_index_10242 = get_upconv_index(abspath+'/neigh_indices/adj_mat_order_10242_rotated_' + str(rotated) + '.mat')
    upconv_top_index_2562, upconv_down_index_2562 = get_upconv_index(abspath+'/neigh_indices/adj_mat_order_2562_rotated_' + str(rotated) + '.mat')
    upconv_top_index_642, upconv_down_index_642 = get_upconv_index(abspath+'/neigh_indices/adj_mat_order_642_rotated_' + str(rotated) + '.mat')
    upconv_top_index_162, upconv_down_index_162 = get_upconv_index(abspath+'/neigh_indices/adj_mat_order_162_rotated_' + str(rotated) + '.mat')
    
    #TODO: return tuples of each level
    return upconv_top_index_163842, upconv_down_index_163842, upconv_top_index_40962, upconv_down_index_40962, upconv_top_index_10242, upconv_down_index_10242,  upconv_top_index_2562, upconv_down_index_2562,  upconv_top_index_642, upconv_down_index_642, upconv_top_index_162, upconv_down_index_162


def get_upconv_index(order_path):  
    adj_mat_order = sio.loadmat(order_path)
    adj_mat_order = adj_mat_order['adj_mat_order']
    adj_mat_order = adj_mat_order -1
    nodes = len(adj_mat_order)
    next_nodes = int((len(adj_mat_order)+6)/4)
    upconv_top_index = np.zeros(next_nodes).astype(np.int64) - 1
    for i in range(next_nodes):
        upconv_top_index[i] = i * 7 + 6
    upconv_down_index = np.zeros((nodes-next_nodes) * 2).astype(np.int64) - 1
    for i in range(next_nodes, nodes):
        raw_neigh_order = adj_mat_order[i]
        parent_nodes = raw_neigh_order[raw_neigh_order < next_nodes]
        assert(len(parent_nodes) == 2)
        for j in range(2):
            parent_neigh = adj_mat_order[parent_nodes[j]]
            index = np.where(parent_neigh == i)[0][0]
            upconv_down_index[(i-next_nodes)*2 + j] = parent_nodes[j] * 7 + index
    
    return upconv_top_index, upconv_down_index


def get_upsample_order(n_vertex):
    n_last = int((n_vertex+6)/4)
    neigh_orders = get_neighs_order(abspath+'/neigh_indices/adj_mat_order_'+ str(n_vertex) +'_rotated_0.mat')
    neigh_orders = neigh_orders.reshape(n_vertex, 7)
    neigh_orders = neigh_orders[n_last:,:]
    row, col = (neigh_orders < n_last).nonzero()
    assert len(row) == (n_vertex - n_last)*2, "len(row) == (n_vertex - n_last)*2, error!"
    
    u, indices, counts = np.unique(row, return_index=True, return_counts=True)
    assert len(u) == n_vertex - n_last, "len(u) == n_vertex - n_last, error"
    assert u.min() == 0 and u.max() == n_vertex-n_last-1, "u.min() == 0 and u.max() == n_vertex-n_last-1, error"
    assert (indices == np.asarray(list(range(n_vertex - n_last))) * 2).sum() == n_vertex - n_last, "(indices == np.asarray(list(range(n_vertex - n_last))) * 2).sum() == n_vertex - n_last, error"
    assert (counts == 2).sum() == n_vertex - n_last, "(counts == 2).sum() == n_vertex - n_last, error"
    
    upsample_neighs_order = neigh_orders[row, col]
    
    return upsample_neighs_order


abspath = 'C:/Users/DELL/Desktop/kaiti/SphericalUNetPackage/sphericalunet/utils'
neigh_orders = Get_neighs_order()
neigh_orders = neigh_orders[1:]
a, b, upconv_top_index_40962, upconv_down_index_40962, upconv_top_index_10242, upconv_down_index_10242,  upconv_top_index_2562, upconv_down_index_2562,  upconv_top_index_642, upconv_down_index_642, upconv_top_index_162, upconv_down_index_162 = Get_upconv_index()

In [4]:
import scipy.io as sio 
import torch.nn as nn
class down_block(nn.Module):
    """
    downsampling block in spherical unet
    mean pooling => (conv => BN => ReLU) * 2
    
    """
    def __init__(self, conv_layer, in_ch, out_ch, neigh_orders, pool_neigh_orders, first = False):
        super(down_block, self).__init__()


#        Batch norm version
        if first:
            self.block = nn.Sequential(
                conv_layer(in_ch, out_ch, neigh_orders),
                nn.BatchNorm1d(out_ch, momentum=0.15, affine=True, track_running_stats=False),
                nn.LeakyReLU(0.2, inplace=True),
                conv_layer(out_ch, out_ch, neigh_orders),
                nn.BatchNorm1d(out_ch, momentum=0.15, affine=True, track_running_stats=False),
                nn.LeakyReLU(0.2, inplace=True)
        )
            
        else:
            self.block = nn.Sequential(
                pool_layer(pool_neigh_orders, 'mean'),
                conv_layer(in_ch, out_ch, neigh_orders),
                nn.BatchNorm1d(out_ch, momentum=0.15, affine=True, track_running_stats=False),
                nn.LeakyReLU(0.2, inplace=True),
                conv_layer(out_ch, out_ch, neigh_orders),
                nn.BatchNorm1d(out_ch, momentum=0.15, affine=True, track_running_stats=False),
                nn.LeakyReLU(0.2, inplace=True),
        )


    def forward(self, x):
        # batch norm version
        x = self.block(x)
        
        return x

class upconv_layer(nn.Module):
    """
    The transposed convolution layer on icosahedron discretized sphere using 1-ring filter
    
    Input: 
        N x in_feats, tensor
    Return:
        ((Nx4)-6) x out_feats, tensor
    
    """  

    def __init__(self, in_feats, out_feats, upconv_top_index, upconv_down_index):
        super(upconv_layer, self).__init__()

        self.in_feats = in_feats
        self.out_feats = out_feats
        self.upconv_top_index = upconv_top_index
        self.upconv_down_index = upconv_down_index
        self.weight = nn.Linear(in_feats, 7 * out_feats)
        
    def forward(self, x):
       
        raw_nodes = x.size()[0]
        new_nodes = int(raw_nodes*4 - 6)
        x = self.weight(x)
        x = x.view(len(x) * 7, self.out_feats)
        x1 = x[self.upconv_top_index]
        assert(x1.size() == torch.Size([raw_nodes, self.out_feats]))
        x2 = x[self.upconv_down_index].view(-1, self.out_feats, 2)
        x = torch.cat((x1,torch.mean(x2, 2)), 0)
        assert(x.size() == torch.Size([new_nodes, self.out_feats]))
        return x
    
    
class up_block(nn.Module):
    """Define the upsamping block in spherica uent
    upconv => (conv => BN => ReLU) * 2
    
    Parameters:
            in_ch (int) - - input features/channels
            out_ch (int) - - output features/channels    
            neigh_orders (tensor, int)  - - conv layer's filters' neighborhood orders
            
    """    
    def __init__(self, conv_layer, in_ch, out_ch, neigh_orders, upconv_top_index, upconv_down_index):
        super(up_block, self).__init__()
        
        self.up = upconv_layer(in_ch, out_ch, upconv_top_index, upconv_down_index)
        
        # batch norm version
        self.double_conv = nn.Sequential(
             conv_layer(in_ch, out_ch, neigh_orders),
             nn.BatchNorm1d(out_ch, momentum=0.15, affine=True, track_running_stats=False),
             nn.LeakyReLU(0.2, inplace=True),
             conv_layer(out_ch, out_ch, neigh_orders),
             nn.BatchNorm1d(out_ch, momentum=0.15, affine=True, track_running_stats=False),
             nn.LeakyReLU(0.2, inplace=True)
        )

    def forward(self, x1, x2):
        
        x1 = self.up(x1)
        x = torch.cat((x1, x2), 1) 
        x = self.double_conv(x)

        return x
    
class Unet_40k(nn.Module):
    """Define the Spherical UNet structure

    """    
    def __init__(self, in_ch, out_ch):
        """ Initialize the Spherical UNet.

        Parameters:
            in_ch (int) - - input features/channels
            out_ch (int) - - output features/channels
        """
        super(Unet_40k, self).__init__()

        #neigh_indices_10242, neigh_indices_2562, neigh_indices_642, neigh_indices_162, neigh_indices_42 = Get_indices_order()
        #neigh_orders_10242, neigh_orders_2562, neigh_orders_642, neigh_orders_162, neigh_orders_42, neigh_orders_12 = Get_neighs_order()
        
        neigh_orders = Get_neighs_order()
        neigh_orders = neigh_orders[1:]
        a, b, upconv_top_index_40962, upconv_down_index_40962, upconv_top_index_10242, upconv_down_index_10242,  upconv_top_index_2562, upconv_down_index_2562,  upconv_top_index_642, upconv_down_index_642, upconv_top_index_162, upconv_down_index_162 = Get_upconv_index() 

        chs = [in_ch, 32, 64, 128, 256, 512]
        
        conv_layer = onering_conv_layer

        self.down1 = down_block(conv_layer, chs[0], chs[1], neigh_orders[0], None, True)
        self.down2 = down_block(conv_layer, chs[1], chs[2], neigh_orders[1], neigh_orders[0])
        self.down3 = down_block(conv_layer, chs[2], chs[3], neigh_orders[2], neigh_orders[1])
        self.down4 = down_block(conv_layer, chs[3], chs[4], neigh_orders[3], neigh_orders[2])
        self.down5 = down_block(conv_layer, chs[4], chs[5], neigh_orders[4], neigh_orders[3])
      
        self.up1 = up_block(conv_layer, chs[5], chs[4], neigh_orders[3], upconv_top_index_642, upconv_down_index_642)
        self.up2 = up_block(conv_layer, chs[4], chs[3], neigh_orders[2], upconv_top_index_2562, upconv_down_index_2562)
        self.up3 = up_block(conv_layer, chs[3], chs[2], neigh_orders[1], upconv_top_index_10242, upconv_down_index_10242)
        self.up4 = up_block(conv_layer, chs[2], chs[1], neigh_orders[0], upconv_top_index_40962, upconv_down_index_40962)
        
        self.outc = nn.Sequential(
                nn.Linear(chs[1], out_ch)
                )
                
        
    def forward(self, x):
        x2 = self.down1(x)
        x3 = self.down2(x2)
        x4 = self.down3(x3)
        x5 = self.down4(x4)
        x6 = self.down5(x5)
        
        x = self.up1(x6, x5)
        x = self.up2(x, x4)
        x = self.up3(x, x3)
        x = self.up4(x, x2) # 40962 * 32
        
        x = self.outc(x) # 40962 * 36
        return x

In [18]:
import matplotlib.pyplot as plt
import torch

################################################################
""" hyper-parameters """
# cuda = torch.device('cpu')
device = torch.device("cuda:0")
batch_size = 1
model_name = 'Unet_infant'  # 'Unet_infant', 'Unet_18', 'Unet_2ring', 'Unet_repa', 'fcn', 'SegNet', 'SegNet_max'
up_layer = 'upsample_interpolation' # 'upsample_interpolation', 'upsample_fixindex' 
in_channels = 1
out_channels = 1
learning_rate = 0.001
momentum = 0.99
weight_decay = 0.0001
fold = 1 # 1,2,3 
################################################################

data_file = 'C:/Users/DELL/Desktop/kaiti/SphericalUNet_ttl/data.npy'
label_file = 'C:/Users/DELL/Desktop/kaiti/SphericalUNet_ttl/data.npy'

# 从.npy文件中加载数据
data_array = np.load(data_file)
train_data = data_array[0:400,:]
val_data = data_array[400:,:]
print(data_array.shape,train_data.shape)
label_array = np.load(label_file)
train_label = label_array[0:400,:]
val_label = label_array[400:,:]

neighbors_path = 'C:/Users/DELL/Desktop/kaiti/Spherical_U-Net/neigh_indices/adj_mat_order_10242.mat'

# def get_neighs_order(neighbors_path):
#     adj_mat_order = sio.loadmat(neighbors_path)
#     adj_mat_order = adj_mat_order['adj_mat_order']
#     neigh_orders = np.zeros((len(adj_mat_order), 7))
#     neigh_orders[:,0:6] = adj_mat_order-1
#     neigh_orders[:,6] = np.arange(len(adj_mat_order))
#     neigh_orders = np.ravel(neigh_orders).astype(np.int64)
    
#     return neigh_orders

class BrainSphere(torch.utils.data.Dataset):

    def __init__(self, data, label):
        self.data = data
        self.labels = label
        
    def __len__(self):
        return len(self.labels)


    def __getitem__(self, idx):
        sample = {
            'data': torch.from_numpy(self.data[idx]),
            'label': torch.from_numpy(self.data[idx])
        }
        return sample



train_dataset = BrainSphere(train_data, train_label)

train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=1, shuffle=False, pin_memory=True)

val_dataset = BrainSphere(val_data, val_label)

val_dataloader = torch.utils.data.DataLoader(val_dataset, batch_size=1, shuffle=False, pin_memory=True)
# for batch in train_dataloader:
#     data, labels = batch['data'], batch['label']
#     # 在这里可以进行你的训练或其他操作
#     print(f"Batch Data Shape: {data.shape}, Batch Label Shape: {labels.shape}")

model = Unet_40k(in_ch=1, out_ch=out_channels)

print("{} paramerters in total".format(sum(x.numel() for x in model.parameters())))
model.to(device)
criterion = nn.MSELoss()
val_criterion = nn.L1Loss()
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
scheduler1 = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.5)


def train_step(data, target):
    # model.train()
    data, target = data.to(device), target.to(device)

    prediction = model(data)
    # print(prediction)
    loss = criterion(prediction, target)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    return loss.item(),prediction


# # train_dice = [0, 0, 0, 0, 0]
print('length of dataloader:', len(train_dataloader))
loss_list = []
val_loss_list = []
for epoch in range(100):
    total_loss = 0
    pre_list = []
    label_list = []
    model.train() 

    if epoch % 20 == 9:
        scheduler1.step()
    
    for batch_idx, train_data in enumerate(train_dataloader):
        data = train_data['data']
        data = torch.squeeze(data).unsqueeze(1)
        target = train_data['label']
        target = torch.tensor(target, dtype=torch.float32)
        # print(target)
        loss,pre = train_step(data, target)
        pre_list.append(pre.item())
        label_list.append(target.item())
        total_loss = total_loss + loss
    # print(total_loss/len(train_dataloader))
    loss_list.append(total_loss/len(train_dataloader))
    # print(pre_list)
    # print(label_list)

    # model.eval()  # 切换到评估模式
        
    # with torch.no_grad():  # 禁用梯度计算，因为在验证阶段我们不需要反向传播
    #     val_loss_total = 0
    #     val_pre_list = []
    #     val_label_list = []
        
    #     for batch_idy, val_dataset in enumerate(val_dataloader):
    #         val_data = val_dataset['data']
    #         val_data = torch.squeeze(val_data).unsqueeze(1)
    #         val_target = val_dataset['label']
    #         val_target = torch.tensor(val_target, dtype=torch.float32)
    #         data, target = val_data.to(device), val_target.to(device)

    #         prediction = model(data)
    #         # print(prediction)
    #         val_loss = val_criterion(prediction, target)
    #         val_loss = val_loss.item()
            


    #         val_pre_list.append(prediction.item())
    #         val_label_list.append(val_target.item())
    #         val_loss_total = val_loss_total + val_loss
        
    #     # 计算平均验证损失
    #     avg_val_loss = val_loss_total / len(val_dataloader)
    #     val_loss_list.append(val_loss_total / len(val_dataloader))
    #     print(avg_val_loss)
    #     correlation_coefficient = np.corrcoef(val_pre_list, val_label_list)[0, 1]
    #     print(f'Correlation Coefficient: {correlation_coefficient}')
        # print(val_pre_list)
        # print(val_label_list)



plt.plot(loss_list)
# val_loss_list2 = []
# for i in val_loss_list:
#     val_loss_list2.append(i.item())

plt.plot(val_loss_list)

(504, 10242) (400, 10242)
6721697 paramerters in total


RuntimeError: CUDA error: device-side assert triggered
CUDA kernel errors might be asynchronously reported at some other API call,so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1.