In [21]:
from layers import Unet_40k
import torch
import torch.nn as nn
import torchvision
import scipy.io as sio 
import numpy as np
import glob
import os
import nibabel as nib

# from utils import compute_weight
import matplotlib.pyplot as plt
# from tensorboardX import SummaryWriter
# writer = SummaryWriter('log/a')

################################################################
""" hyper-parameters """
cuda = torch.device('cuda:0')
batch_size = 1
model_name = 'Unet_infant'  # 'Unet_infant', 'Unet_18', 'Unet_2ring', 'Unet_repa', 'fcn', 'SegNet', 'SegNet_max'
up_layer = 'upsample_interpolation' # 'upsample_interpolation', 'upsample_fixindex' 
in_channels = 3
out_channels = 36
learning_rate = 0.001
momentum = 0.99
weight_decay = 0.0001
fold = 1 # 1,2,3 
################################################################


class BrainSphere(torch.utils.data.Dataset):

    def __init__(self, root):

        self.files = sorted(glob.glob(os.path.join(root, '*\surf')))
        # print(self.files)



    def __getitem__(self, index):

        file = self.files[index]
        # print(file)
        # print(os.path.join(file, 'lh.ico7.curv'))
        curv = nib.freesurfer.io.read_morph_data(os.path.join(file, 'lh.ico6.curv'))
        thickness = nib.freesurfer.io.read_morph_data(os.path.join(file, 'lh.ico6.thickness'))
        sulc = nib.freesurfer.io.read_morph_data(os.path.join(file, 'lh.ico6.sulc'))
        area = nib.freesurfer.io.read_morph_data(os.path.join(file, 'lh.ico6.area'))
        curv = torch.from_numpy(curv/np.max(curv)).unsqueeze(1)
        thickness = torch.from_numpy(thickness/np.max(thickness)).unsqueeze(1)
        sulc = torch.from_numpy(sulc/np.max(sulc)).unsqueeze(1)
        area = torch.from_numpy(area/np.max(area)).unsqueeze(1)
        # print(curv.shape, thickness.shape, sulc.shape, area.shape)
        feats = torch.cat((curv, sulc,thickness), 1)
        label = area
        return feats, label


    def __len__(self):
        return len(self.files)



train_dataset = BrainSphere('F:/New_SphericalNN/')

    
train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True, pin_memory=True)
# val_dataloader = torch.utils.data.DataLoader(val_dataset, batch_size=batch_size, shuffle=False, pin_memory=True)


model = Unet_40k(3, 1, abspath='C:/Users/DELL/Desktop/kaiti/SphericalUNetPackage/sphericalunet/utils')
print("{} paramerters in total".format(sum(x.numel() for x in model.parameters())))
model.cuda(cuda)
criterion = nn.L1Loss()
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'max', factor=0.2, patience=1, verbose=True, threshold=0.0001, threshold_mode='rel', min_lr=0.000001)


def train_step(data, target):
    model.train()
    data, target = data.cuda(cuda), target.cuda(cuda)

    prediction = model(data)
    # print(prediction)
    
    loss = criterion(prediction, target)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    return loss.item()


# def compute_dice(pred, gt):

#     pred = pred.cpu().numpy()
#     gt = gt.cpu().numpy()
    
#     dice = np.zeros(36)
#     for i in range(36):
#         gt_indices = np.where(gt == i)[0]
#         pred_indices = np.where(pred == i)[0]
#         dice[i] = 2 * len(np.intersect1d(gt_indices, pred_indices))/(len(gt_indices) + len(pred_indices))
#     return dice


# def val_during_training(dataloader):
#     model.eval()

#     dice_all = np.zeros((len(dataloader),36))
#     for batch_idx, (data, target) in enumerate(dataloader):
#         data = data.squeeze()
#         target = target.squeeze()
#         data, target = data.cuda(cuda), target.cuda(cuda)
#         with torch.no_grad():
#             prediction = model(data)
            
#         prediction = prediction.max(1)[1]
#         dice_all[batch_idx,:] = compute_dice(prediction, target)

#     return dice_all

from sphericalunet.utils.vtk import read_vtk, write_vtk, resample_label
in_file = 'C:/Users/DELL/Desktop/kaiti/Spherical_U-Net/examples/left_hemisphere/40962/test1.lh.40k.vtk'
data = read_vtk(in_file)
curv_temp = data['curv']
n_vertices = 40962
curv = torch.from_numpy(data['curv'][0:n_vertices]).unsqueeze(1) # use curv data with 40k vertices
sulc = torch.from_numpy(data['sulc'][0:n_vertices]).unsqueeze(1) # use sulc data with 40k vertices


train_dice = [0, 0, 0, 0, 0]
for epoch in range(100):
    total_loss = 0
    for batch_idx, (data, target) in enumerate(train_dataloader):
        data = data.squeeze()
        target = target.squeeze()
        # print(data.shape)
        loss = train_step(data, target)
        total_loss += loss
    print(total_loss/len(train_dataloader))
    

6722145 paramerters in total
0.11331276505688827
0.1105934940179189
0.11042880231638749
0.10922851771612963
0.109218665137887
0.10852322144806385
0.10754507834712665
0.10710382728775342
0.10665958893299103
0.1064228210200866
0.10690698397656281
0.10620332846542199


KeyboardInterrupt: 