In [1]:
## Imports
from __future__ import print_function

import torch
from torch import optim
from torch.utils.data import DataLoader
import torch.nn as nn

from unet.unet_model import UNet

from utils.dataset import CrossValDataset
from utils.model_eval import dice_coefficient
from utils.data import RegressorDataset

from torch.utils.tensorboard import SummaryWriter

import numpy as np

In [2]:
## Device
# Run on the GPU when CUDA is usable, otherwise fall back to the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print('Using GPU' if device == 'cuda' else 'Using CPU')

Using GPU


In [3]:
## Load Model
# Checkpoint file whose deserialized contents are later handed to the network
# via net.apply_state_dict(pretrained_model).
modelpath = "D:/autopos/train-spine/unet-spine-save4_63.pth"

# Deserialize onto the CPU regardless of where the checkpoint was saved; the
# network itself is moved to `device` after the weights are applied.
# NOTE(review): torch.load unpickles the file, so only load checkpoints that
# come from a trusted source.
pretrained_model = torch.load(modelpath, map_location=torch.device('cpu'))

In [4]:
## Dataset
# Root folder containing the group_00 ... group_04 cross-validation folds.
groupPath = "D:/autopos/train-spine"
# CSV path passed as the third positional argument to RegressorDataset.
originPath = "D:/autopos/train/origin.csv"
# Batch size used by both the training and the testing DataLoader.
batch_size = 8

## Training
# Constructor arguments for UNet.
n_channels = 1
n_classes = 14

# Hyper-parameters forwarded to optim.RMSprop.
learning_rate = 1e-3
weight_decay = 1e-8
momentum = 0.9

# Number of passes over the training loader (a checkpoint is saved after each).
num_epochs = 11

## Cropping
# These values are forwarded unchanged as keyword arguments to
# RegressorDataset; their exact meaning is defined there.
# NOTE(review): presumably D is the crop size, N the source image size,
# (xc, yc) a crop centre and mu/sigma the parameters of the random centre
# offset used when centermode == 'rand' -- confirm against utils.data.
D = 350
N = 512
xc, yc = 0, 0
centermode = 'rand'
mu = np.array([0, 0])
sigma = 50

In [5]:
## Train
# Fine-tunes the pretrained UNet with leave-one-group-out cross-validation:
# for each selected partition, that group is held out for validation and the
# remaining groups are used for training. Per-batch and per-epoch losses are
# written to TensorBoard and a checkpoint is saved after every epoch.

# One directory per cross-validation fold: <groupPath>/group_00 ... group_04.
fullDirList = np.array(["{}/group_{:02d}".format(groupPath, i) for i in range(5)])
inds = np.arange(len(fullDirList))

# Only fold 4 is held out in this run. Iterate over range(len(fullDirList))
# instead to run the full cross-validation.
for partition in [4]:
    # Split the fold indices into the held-out fold and the training folds.
    trainInd = np.delete(inds, [partition])
    testInd = inds[partition]
    trainDirs = fullDirList[trainInd]
    testDir = [fullDirList[testInd]]

    # Each fold directory holds an "img" and a "seg" sub-folder.
    trainDataset = RegressorDataset([x + "/img" for x in trainDirs], [x + "/seg" for x in trainDirs], originPath,
                                    D=D, N=N, xc=xc, yc=yc, centermode=centermode, mu=mu, sigma=sigma)
    testDataset = RegressorDataset([x + "/img" for x in testDir], [x + "/seg" for x in testDir], originPath,
                                   D=D, N=N, xc=xc, yc=yc, centermode=centermode, mu=mu, sigma=sigma)

    # NOTE(review): the training loader is not shuffled (DataLoader default);
    # confirm this is intended before enabling shuffle=True.
    trainLoader = DataLoader(trainDataset, batch_size=batch_size)
    testLoader = DataLoader(testDataset, batch_size=batch_size)

    # Define Net
    net = UNet(n_channels=n_channels, n_classes=n_classes)
    # NOTE(review): nn.Module's standard method is load_state_dict;
    # apply_state_dict is presumably provided by the project's UNet class.
    net.apply_state_dict(pretrained_model)
    net.to(device=device)
    print("Loaded pretrained model.")

    # Optimizer
    optimizer = optim.RMSprop(net.parameters(), lr=learning_rate, weight_decay=weight_decay, momentum=momentum)
    critereon = nn.CrossEntropyLoss()

    # Tensorboard
    writer = SummaryWriter('runs/crop_spine_model_partition_' + str(partition))

    for epoch in range(num_epochs):
        # ---- Training pass ----
        net.train()
        epoch_train_loss = 0
        for trainIndex, batch in enumerate(trainLoader, 0):
            imgs = batch['img'].to(device=device, dtype=torch.float32)
            segs_gt = batch['seg'].to(device=device, dtype=torch.long)

            segs_pr = net(imgs)

            loss = critereon(segs_pr, segs_gt)
            # Accumulate the mean of the per-batch losses over the epoch.
            epoch_train_loss += loss.item() / len(trainLoader)

            optimizer.zero_grad()
            loss.backward()
            # Clamp every gradient element to [-0.1, 0.1] before the update.
            nn.utils.clip_grad_value_(net.parameters(), 0.1)
            optimizer.step()

            print("Train: " + str(loss.item()))

            writer.add_scalar('Batch Training Loss',
                              loss.item() / len(trainLoader),
                              epoch * len(trainDataset) + trainIndex)

        writer.add_scalar('Training Loss',
                          epoch_train_loss,
                          epoch)

        # ---- Validation pass ----
        net.eval()
        epoch_test_loss = 0
        # No gradients are needed for evaluation; disabling autograd avoids
        # building a graph for every validation batch.
        with torch.no_grad():
            for testIndex, batch in enumerate(testLoader, 0):
                imgs = batch['img'].to(device=device, dtype=torch.float32)
                segs_gt = batch['seg'].to(device=device, dtype=torch.long)

                segs_pr = net(imgs)

                loss = critereon(segs_pr, segs_gt)
                epoch_test_loss += loss.item() / len(testLoader)

                print("Test: " + str(loss.item()))

                # Scale by the number of *test* batches, mirroring how the
                # training scalar is scaled by the number of training batches.
                writer.add_scalar('Batch Testing Loss',
                                  loss.item() / len(testLoader),
                                  epoch * len(testDataset) + testIndex)

        writer.add_scalar('Validation Loss',
                          epoch_test_loss,
                          epoch)

        # Checkpoint the weights after every epoch.
        savefile = "D:\\autopos\\nets\\unet-crop-spine-save" + str(partition) + "_" + str(epoch) + ".pth"
        torch.save(net.state_dict(), savefile)

    # Flush pending events and release the event file for this partition.
    writer.close()


Loaded pretrained model.
Train: 11.356849670410156
Train: 9.258655548095703
Train: 7.381056308746338
Train: 3.9053854942321777
Train: 2.7838134765625
Train: 3.291076183319092
Train: 1.5488122701644897
Train: 1.743996262550354
Train: 1.9209836721420288
Train: 1.636531949043274
Train: 1.7179573774337769
Train: 1.7365837097167969
Train: 1.2615821361541748
Train: 1.3345918655395508
Train: 1.1208629608154297
Train: 1.006991982460022
Train: 1.2294667959213257
Train: 1.0680245161056519
Train: 1.0367028713226318
Train: 0.9852482080459595
Train: 0.9516698718070984
Train: 0.7056750655174255
Train: 0.788703441619873
Train: 0.6976118087768555
Train: 0.8603591322898865
Train: 0.8013346195220947
Train: 0.6728008985519409
Train: 1.0067591667175293
Train: 0.795354962348938
Train: 0.6232860684394836
Train: 0.7639835476875305
Train: 0.9289135932922363
Train: 0.9751973152160645
Train: 1.2001515626907349
Train: 0.7630053758621216
Train: 0.6630017757415771
Train: 0.5673465728759766
Train: 0.581792294979095

Train: 0.2717452943325043
Train: 0.21141162514686584
Train: 0.23146982491016388
Train: 0.28670015931129456
Train: 0.3218511641025543
Train: 0.283911794424057
Train: 0.203388512134552
Train: 0.28819724917411804
Train: 0.42934027314186096
Train: 0.5849730968475342
Train: 0.8297816514968872
Train: 0.4165876805782318
Train: 0.28989484906196594
Train: 0.29950347542762756
Train: 0.3055998384952545
Train: 0.26330938935279846
Train: 0.43730542063713074
Train: 0.3011663854122162
Train: 0.5530257225036621
Train: 0.2749188542366028
Train: 0.25265049934387207
Train: 0.29034072160720825
Train: 0.2952401638031006
Train: 0.3760751485824585
Train: 0.6912100315093994
Train: 0.41421371698379517
Train: 0.21759995818138123
Train: 0.3183695375919342
Train: 0.2549741268157959
Train: 0.4106783866882324
Train: 0.5003141760826111
Train: 0.5510371327400208
Train: 0.2114083170890808
Train: 0.24564331769943237
Train: 0.33430808782577515
Train: 0.44059380888938904
Train: 0.32186487317085266
Train: 0.46438741683959

Train: 0.2346910834312439
Train: 0.2164393663406372
Train: 0.31095412373542786
Train: 0.26412808895111084
Train: 0.5909201502799988
Train: 0.2506982088088989
Train: 0.22055989503860474
Train: 0.19641856849193573
Train: 0.31732651591300964
Train: 0.19811657071113586
Train: 0.244875967502594
Train: 0.24375571310520172
Test: 0.15612287819385529
Test: 0.17222779989242554
Test: 0.44653409719467163
Test: 0.21536004543304443
Test: 0.5501524209976196
Test: 0.6165213584899902
Test: 0.2839556038379669
Test: 0.3280245363712311
Test: 0.24229615926742554
Test: 0.33129405975341797
Test: 0.5012001991271973
Test: 0.5374135375022888
Test: 0.44984981417655945
Test: 0.2713286578655243
Test: 0.29024258255958557
Test: 0.19547639787197113
Test: 0.4220599830150604
Test: 0.15664681792259216
Test: 0.2143891602754593
Test: 0.5034933090209961
Test: 0.44113394618034363
Test: 0.5193061232566833
Test: 0.2477027177810669
Test: 0.3187509775161743
Test: 0.5194794535636902
Test: 0.4753905236721039
Train: 0.131123498082

Train: 0.3688276708126068
Train: 0.23829174041748047
Train: 0.20655813813209534
Train: 0.35842660069465637
Train: 0.2685759365558624
Train: 0.29212093353271484
Train: 0.2253243774175644
Train: 0.1994466334581375
Train: 0.3071543574333191
Train: 0.29725682735443115
Train: 0.28167131543159485
Train: 0.2142210453748703
Train: 0.23704266548156738
Train: 0.22139319777488708
Train: 0.22243140637874603
Train: 0.17864251136779785
Train: 0.21490994095802307
Train: 0.20234207808971405
Train: 0.11182643473148346
Train: 0.32857421040534973
Train: 0.27234745025634766
Train: 0.2525503635406494
Train: 0.3113217353820801
Train: 0.38862115144729614
Train: 0.30549925565719604
Train: 0.3347214162349701
Train: 0.3496522903442383
Train: 0.13815343379974365
Train: 0.1546308398246765
Train: 0.16237063705921173
Train: 0.34764158725738525
Train: 0.19550958275794983
Train: 0.19047069549560547
Train: 0.12670299410820007
Train: 0.12658661603927612
Train: 0.11944962292909622
Train: 0.41517478227615356
Train: 0.419

Train: 0.08900658786296844
Train: 0.10849607735872269
Train: 0.21299059689044952
Train: 0.32924988865852356
Train: 0.36756327748298645
Train: 0.6209039688110352
Train: 0.3072895407676697
Train: 0.12130176275968552
Train: 0.1977616250514984
Train: 0.22370412945747375
Train: 0.16813074052333832
Train: 0.3145946264266968
Train: 0.26660823822021484
Train: 0.20324529707431793
Train: 0.10847138613462448
Train: 0.18615779280662537
Train: 0.23929138481616974
Train: 0.22022634744644165
Train: 0.1635468751192093
Train: 0.36626529693603516
Train: 0.35949259996414185
Train: 0.11485014110803604
Train: 0.1650625318288803
Train: 0.13518580794334412
Train: 0.28690534830093384
Train: 0.32823461294174194
Train: 0.3784538209438324
Train: 0.10837553441524506
Train: 0.09981029480695724
Train: 0.278298944234848
Train: 0.35546329617500305
Train: 0.18401852250099182
Train: 0.3525199890136719
Train: 0.3838048279285431
Train: 0.12143690884113312
Train: 0.11091136932373047
Train: 0.19598616659641266
Train: 0.218