In [1]:
import os
from torch.utils.data import Dataset, DataLoader
import matplotlib.pyplot as plt
from skimage import io, transform
from torchvision import transforms, utils
import torch
import numpy as np
import nibabel as nib
from random import randint
from PIL import Image
import torch.optim as optim
import time
import QuickNAT as QN
import torch.nn as nn

In [2]:
gpu_id = 1
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = str(gpu_id)

device = torch.device('cuda')

In [3]:
model = QN.QuickNAT(1,64,256)
model_params = list(model.parameters())
nb_param=0
for param in model.parameters():
    nb_param+=np.prod(list(param.data.size()))
print(nb_param)
model = model.to(device)

3308688


In [4]:
class TrainDataset(Dataset):
    """Training dataset with mask image mapping to classes"""
    def __init__(self, T1a_dir, parc5a_dir, transform=None):
        """
        Args:
            T1a_dir (string): Directory with T1w image in axial plane
            transform (callable): Optional transform to be applied on a sample
            parc5a_dir (string): Directory with parcellation scale 5 in axial plane
        """
        self.T1a_dir = T1a_dir
        self.transform = transform
        self.parc5a_dir = parc5a_dir
        
    def __len__(self):
        T1a_list = os.listdir(self.T1a_dir)
        return len(T1a_list)
    
    
    def __getitem__(self, idx):
        T1a_list = os.listdir(T1a_dir)
        parc5a_list = os.listdir(parc5a_dir)
        
        T1a_str = T1a_list[idx]
        
        T1a_arr = io.imread(os.path.join(T1a_dir, T1a_str))
        T1a_tensor = torch.from_numpy(T1a_arr)
        
        compose_T1 = transforms.Compose([transforms.ToPILImage(), 
                                         transforms.Resize((128,128),interpolation=Image.NEAREST),
                                         transforms.ToTensor(),
                                         transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
        T1a_tensor = torch.unsqueeze(T1a_tensor, dim = 0)
        T1a_tensor = compose_T1(T1a_tensor)
              
        parc5a_str = parc5a_list[idx]
    
        parc5a_arr = io.imread(os.path.join(parc5a_dir, parc5a_str))
        parc5a_tensor = torch.from_numpy(parc5a_arr)
        
        compose = transforms.Compose([transforms.ToPILImage(),
                                      transforms.Resize((128,128),interpolation=Image.NEAREST), 
                                      transforms.ToTensor()])
        
        parc5a_tensor = torch.unsqueeze(parc5a_tensor, dim = 0)
        parc5a_tensor = compose(parc5a_tensor)
        parc5a_tensor = parc5a_tensor.squeeze()
        
        parc5a_tensor = torch.round(parc5a_tensor / 0.0039).byte()
      
        sample = {'T1a':T1a_tensor, 'parc5a':parc5a_tensor}
        
        if self.transform:
            T1a = self.transform(T1a_tensor)
            sample = {'T1a':T1a, 'parc5a':parc5a}
            
        return sample

In [None]:
for iteration in range(5):
    start=time.time()
    for sub_idx in range(330):

        T1a_dir = '/home/xiaoyu/MRIdata/T1w/axial/sub{}'.format(sub_idx)
   
        parc5a_dir = '/home/xiaoyu/MRIdata/parc_5/axial/sub{}'.format(sub_idx)
   
        T1a_list = os.listdir(T1a_dir)
  
        parc5a_list = os.listdir(parc5a_dir)

    
        if sub_idx == 0: # set sub0 as test set.
            print('\nT1w Axial slices num:',len(T1a_list))
            print('\nParc5 Axial slices num:',len(parc5a_list))

            continue
        
        print('\nSubject num:',sub_idx)   
        train_data = TrainDataset(T1a_dir=T1a_dir, parc5a_dir = parc5a_dir)
        dataloader = DataLoader(train_data, batch_size = 5, shuffle = True, num_workers = 4)
    
        criterion = nn.NLLLoss()
        optimizer = optim.Adam(model.parameters() ,lr=0.001)
    
        for epoch in range(0,50):
   
    
            # define the running loss
            running_loss = 0
            running_error = 0
            num_batches=0
      
            for i_batch, sample_batched in enumerate(dataloader):
        
                optimizer.zero_grad()
        
                #get the inputs
                inputs, labels = sample_batched['T1a'], sample_batched['parc5a']
        
                inputs = inputs.to(device)
                labels = labels.to(device)

                inputs.requires_grad_()
        
                #forward + backward +optimize
                scores = model(inputs)
          
                # Define the loss
                loss = criterion(scores, labels.long()) 
                loss.backward()
                optimizer.step()
        
                # compute and accumulate stats
                running_loss += loss.detach().item()
       
                num_batches+=1 
    
            # AVERAGE STATS THEN DISPLAY    
            total_loss = running_loss/num_batches
   
            elapsed = (time.time()-start)/60
        
            print('epoch=',epoch, '\t time=', elapsed,'min', '\t loss=', total_loss )

       
        print('Finish Training')
    print(iteration,'Iteration')


T1w Axial slices num: 182

Parc5 Axial slices num: 182

Subject num: 1
epoch= 0 	 time= 0.06885821024576823 min 	 loss= 2.4378412671991296
epoch= 1 	 time= 0.13643916050593058 min 	 loss= 0.8647952885241121
epoch= 2 	 time= 0.20516642729441326 min 	 loss= 0.8289284206725456
epoch= 3 	 time= 0.27322210868199664 min 	 loss= 0.840316653453015
epoch= 4 	 time= 0.3413116176923116 min 	 loss= 0.8082802646063469
epoch= 5 	 time= 0.40866774320602417 min 	 loss= 0.8084945477343894
epoch= 6 	 time= 0.47652670939763386 min 	 loss= 0.7954363613515287
epoch= 7 	 time= 0.5434702833493551 min 	 loss= 0.8035291248076671
epoch= 8 	 time= 0.6114280104637146 min 	 loss= 0.7874052605113467
epoch= 9 	 time= 0.6788889765739441 min 	 loss= 0.7838964176339072
epoch= 10 	 time= 0.7499068101247152 min 	 loss= 0.7663572976315344
epoch= 11 	 time= 0.8171610832214355 min 	 loss= 0.7838197213572424
epoch= 12 	 time= 0.884792160987854 min 	 loss= 0.7611374107969774
epoch= 13 	 time= 0.9527156829833985 min 	 loss= 0

epoch= 21 	 time= 12.28466215133667 min 	 loss= 0.5081463485383555
epoch= 22 	 time= 12.415239767233532 min 	 loss= 0.49478681826007526
epoch= 23 	 time= 12.546603012084962 min 	 loss= 0.49847152899648695
epoch= 24 	 time= 12.68041519721349 min 	 loss= 0.5035154868068324
epoch= 25 	 time= 12.818085861206054 min 	 loss= 0.4916522890530728
epoch= 26 	 time= 12.948063655694325 min 	 loss= 0.49361430995820743
epoch= 27 	 time= 13.079297816753387 min 	 loss= 0.4923572196448977
epoch= 28 	 time= 13.211278768380483 min 	 loss= 0.5059434686963623
epoch= 29 	 time= 13.34273946682612 min 	 loss= 0.4933133121277835
epoch= 30 	 time= 13.482443992296854 min 	 loss= 0.4853751292905292
epoch= 31 	 time= 13.614001858234406 min 	 loss= 0.4942911488784326
epoch= 32 	 time= 13.745137945810955 min 	 loss= 0.48621720478341385
epoch= 33 	 time= 13.875535893440247 min 	 loss= 0.48234191838953944
epoch= 34 	 time= 14.005155499776205 min 	 loss= 0.48982756202285355
epoch= 35 	 time= 14.143117594718934 min 	 lo