In [1]:
## Imports
from __future__ import print_function

import torch
from torch import optim
from torch.utils.data import DataLoader
import torch.nn as nn

from unet.unet_model import UNet

from utils.dataset import CrossValDataset
from utils.model_eval import dice_coefficient
from utils.data import RegressorDataset

from torch.utils.tensorboard import SummaryWriter

import numpy as np

In [2]:
## Device
# Pick the compute device once: CUDA when torch reports it as available,
# otherwise the CPU. `device` is a plain string consumed by later `.to(...)` calls.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print({'cuda': 'Using GPU', 'cpu': 'Using CPU'}[device])

Using GPU


In [3]:
## Load Model
# Read the saved checkpoint from disk, mapping every tensor onto the CPU so the
# load works regardless of which device the checkpoint was written from.
# NOTE(review): torch.load unpickles the file -- only point this at trusted checkpoints.
modelpath = "D:/autopos/train/unet-save2_13.pth"
pretrained_model = torch.load(
    modelpath,
    map_location=torch.device('cpu'),
)

In [4]:
## Dataset
# Root folder holding the per-fold group directories, and the CSV handed to the
# dataset constructor as its third argument.
groupPath = "D:/autopos/train"
originPath = "D:/autopos/train/origin.csv"
batch_size = 8  # samples per DataLoader batch

## Training
# Network shape (passed to the UNet constructor).
n_channels = 1
n_classes = 13

# RMSprop hyper-parameters.
learning_rate = 1e-3
weight_decay = 1e-8
momentum = 0.9

# Number of passes over the training data per fold.
num_epochs = 11

##Cropping
# Forwarded verbatim to the dataset constructor as keyword arguments.
# NOTE(review): presumably D is the crop size, N the source image size, (xc, yc)
# a fixed crop centre and mu/sigma the distribution of the random centre used
# when centermode == 'rand' -- confirm against the dataset implementation.
D = 350
N = 512
xc = yc = 0
centermode = 'rand'
mu = np.zeros(2, dtype=int)
sigma = 50

In [5]:
## Train
# Cross-validated fine-tuning. For each held-out fold a fresh UNet is initialised
# from the pretrained weights, trained on the remaining folds, evaluated on the
# held-out fold after every epoch, and checkpointed after every epoch.
fullDirList = np.array([groupPath + "/group_%02d" % k for k in range(5)])
inds = np.arange(len(fullDirList))

# Keyword arguments shared by the training and validation datasets.
cropArgs = dict(D=D, N=N, xc=xc, yc=yc, centermode=centermode, mu=mu, sigma=sigma)

# NOTE(review): fold 0 is never used as the held-out fold (the loop starts at 1,
# exactly as the original `range(4)` + 1 did) -- confirm this is intentional.
for partition in range(1, len(fullDirList)):
    # Hold out one group for validation; train on all the others.
    trainInd = np.delete(inds, [partition])
    testInd  = inds[partition]
    trainDirs = fullDirList[trainInd]
    testDir   = [fullDirList[testInd]]

    trainDataset = RegressorDataset([x + "/img" for x in trainDirs],
                                    [x + "/seg" for x in trainDirs],
                                    originPath, **cropArgs)
    testDataset  = RegressorDataset([x + "/img" for x in testDir],
                                    [x + "/seg" for x in testDir],
                                    originPath, **cropArgs)

    # NOTE(review): the training loader does not shuffle, so batches are visited
    # in the same order every epoch -- consider shuffle=True if the dataset is map-style.
    trainLoader  = DataLoader(trainDataset, batch_size=batch_size)
    testLoader   = DataLoader(testDataset, batch_size=batch_size)

    # Define Net: every fold starts again from the same pretrained weights.
    net = UNet(n_channels=n_channels, n_classes=n_classes)
    net.apply_state_dict(pretrained_model)
    net.to(device=device)
    print("Loaded pretrained model.")

    # Optimizer and loss
    optimizer = optim.RMSprop(net.parameters(), lr=learning_rate,
                              weight_decay=weight_decay, momentum=momentum)
    critereon = nn.CrossEntropyLoss()

    # Tensorboard: one run directory per fold.
    writer = SummaryWriter('runs/crop_model_partition_' + str(partition))

    # try/finally guarantees the event file is flushed and closed even if
    # training is interrupted or raises part-way through a fold.
    try:
        for epoch in range(num_epochs):
            # ---- training pass ----
            net.train()
            epoch_train_loss = 0
            for batch in trainLoader:
                imgs    = batch['img'].to(device=device, dtype=torch.float32)
                segs_gt = batch['seg'].to(device=device, dtype=torch.long)

                segs_pr = net(imgs)

                loss = critereon(segs_pr, segs_gt)
                # Accumulate the mean of the per-batch losses for this epoch.
                epoch_train_loss += loss.item() / len(trainLoader)

                optimizer.zero_grad()
                loss.backward()
                # Clamp each gradient element to [-0.1, 0.1] before the update.
                nn.utils.clip_grad_value_(net.parameters(), 0.1)
                optimizer.step()

                print("Train: " + str(loss.item()))

            writer.add_scalar('Training Loss', epoch_train_loss, epoch)

            # ---- validation pass ----
            net.eval()
            epoch_test_loss = 0
            # No gradients are needed for evaluation; disabling autograd avoids
            # building a graph for every validation batch (saves memory and time).
            with torch.no_grad():
                for batch in testLoader:
                    imgs    = batch['img'].to(device=device, dtype=torch.float32)
                    segs_gt = batch['seg'].to(device=device, dtype=torch.long)

                    segs_pr = net(imgs)

                    loss = critereon(segs_pr, segs_gt)
                    epoch_test_loss += loss.item() / len(testLoader)

                    print("Test: " + str(loss.item()))

            writer.add_scalar('Validation Loss', epoch_test_loss, epoch)

            # Checkpoint after every epoch, tagged with fold and epoch number.
            savefile = "D:\\autopos\\nets\\unet-crop-save" + str(partition) + "_" + str(epoch) + ".pth"
            torch.save(net.state_dict(), savefile)
    finally:
        writer.close()


Loaded pretrained model.
Train: 9.246841430664062
Train: 5.990595817565918
Train: 6.013113498687744
Train: 2.3448214530944824
Train: 1.25921630859375
Train: 1.145085334777832
Train: 0.8413614630699158
Train: 2.163578748703003
Train: 0.7918638586997986
Train: 0.8720099925994873
Train: 0.8886170387268066
Train: 1.193792462348938
Test: 3564.714599609375
Test: 3301.105712890625
Test: 4175.58056640625
Train: 0.6596277952194214
Train: 0.8164217472076416
Train: 0.6506467461585999
Train: 0.6111125946044922
Train: 0.5059539079666138
Train: 0.47596341371536255
Train: 0.43259239196777344
Train: 0.3946937918663025
Train: 0.3539423942565918
Train: 0.3193109333515167
Train: 0.3999567925930023
Train: 0.7124047875404358
Test: 119.63856506347656
Test: 106.7889404296875
Test: 114.81185150146484
Train: 0.3193693161010742
Train: 0.47957828640937805
Train: 0.37186485528945923
Train: 0.3694743514060974
Train: 0.37247592210769653
Train: 0.3057083189487457
Train: 0.2823525667190552
Train: 0.28372007608413696


Test: 0.1484951227903366
Test: 0.1422760933637619
Test: 0.15098246932029724
Train: 0.11848039925098419
Train: 0.2788463830947876
Train: 0.15021084249019623
Train: 0.20600685477256775
Train: 0.14274145662784576
Train: 0.1428181529045105
Train: 0.14259089529514313
Train: 0.13383030891418457
Train: 0.1208890751004219
Train: 0.11885460466146469
Train: 0.15150189399719238
Train: 0.4146115481853485
Test: 0.14137734472751617
Test: 0.14399641752243042
Test: 0.1673661172389984
Loaded pretrained model.
Train: 9.223766326904297
Train: 6.0139641761779785
Train: 6.2979044914245605
Train: 1.765523076057434
Train: 1.5792012214660645
Train: 1.1064223051071167
Train: 0.8963019251823425
Train: 1.094346523284912
Train: 0.9170835018157959
Train: 0.7550961971282959
Train: 0.7665002942085266
Train: 1.1747972965240479
Test: 561.0249633789062
Test: 579.0541381835938
Test: 579.1919555664062
Train: 0.6431805491447449
Train: 0.8530577421188354
Train: 0.6405598521232605
Train: 0.5271618366241455
Train: 0.44592359

Train: 0.13578687608242035
Train: 0.11698294430971146
Train: 0.10154145210981369
Test: 0.11780941486358643
Test: 0.16346503794193268
Test: 0.3792361915111542
Train: 0.10453977435827255
Train: 0.3196810185909271
Train: 0.1335659623146057
Train: 0.18451890349388123
Train: 0.13439741730690002
Train: 0.1261638104915619
Train: 0.11614283174276352
Train: 0.12089095264673233
Train: 0.15364034473896027
Train: 0.1283969283103943
Train: 0.11083647608757019
Train: 0.09664690494537354
Test: 0.13288642466068268
Test: 0.19233573973178864
Test: 0.4673595726490021
Train: 0.11079910397529602
Train: 0.3153994083404541
Train: 0.1381772756576538
Train: 0.1689322590827942
Train: 0.1262952983379364
Train: 0.11752308160066605
Train: 0.10625173151493073
Train: 0.11826445162296295
Train: 0.14561405777931213
Train: 0.10931752622127533
Train: 0.10966067016124725
Train: 0.09603403508663177
Test: 0.11289454996585846
Test: 0.1537202000617981
Test: 0.40416219830513
