In [1]:
import sys
sys.path.append('./models/')
import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
import time
import os
from data_loader import Dataset
import models.unet_normals as unet
from tensorboardX import SummaryWriter
# import OpenEXR, Imath
import imageio
from torchvision.utils import make_grid


### Setup Options
Set the various parameters:
- dataroot: The folder where the training data is stored
- file_list: List of filenames of images for training
- batchSize: Batch size for model
- shuffle: If true, will shuffle the dataset
- phase: If 'train', then it's in training mode.
- num_epochs: Number of epochs to train the model for
- imsize: Dimensions of the image (square)
- num_classes: Num of classes in the output
- gpu: Which GPU device to use
- logs_path: The path where the log files (tensorboard) will be saved.

In [2]:
class OPT():
    def __init__(self):
        self.dataroot = './data/'
        self.file_list = './data/datalist'
        self.batchSize = 16
        self.shuffle = False
        self.phase = 'train'
        self.num_epochs = 500
        self.imsize = (288,512)
        self.num_classes = int(2)
        self.gpu = '1'
        self.logs_path = 'logs/exp3'
        self.use_pretrained = False

opt = OPT()

### Setup logging and dataloaders

In [3]:
###################### Options #############################
phase = opt.phase
device = torch.device("cuda:"+ opt.gpu if torch.cuda.is_available() else "cpu")

###################### TensorBoardX #############################
if os.path.exists(opt.logs_path):
    raise Exception('The folder \"{}\" already exists! Define a new log path or delete old contents.'.format(opt.logs_path))
    
writer = SummaryWriter(opt.logs_path, comment='create-graph')
graph_created = False

###################### DataLoader #############################
dataloader = Dataset(opt)


### Create the model
We use a UNet model. The last few layers of this model are modified to return a 3 channel image, containing the x,y,z values of surface normal vectors.

In [4]:
###################### ModelBuilder #############################
model = unet.Unet(num_classes=opt.num_classes)
criterion = nn.CrossEntropyLoss(reduction='sum').to(device)

# Load weights from checkpoint
if (opt.use_pretrained == True):
    checkpoint_path = 'logs/exp1/checkpoints/checkpoint-epoch_405.pth'
    model.load_state_dict(torch.load(checkpoint_path))

model = model.to(device)
model.train()

###################### Setup Optimazation #############################
optimizer = torch.optim.Adam(model.parameters(), lr=0.0001, weight_decay=0.0001)
exp_lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=7, gamma=0.1)

###################### Loss fuction #############################
def flatten_logits(logits, number_of_classes):
    """Flattens the logits batch except for the logits dimension"""
    logits_permuted = logits.permute(0, 2, 3, 1)
    logits_permuted_cont = logits_permuted.contiguous()
    logits_flatten = logits_permuted_cont.view(-1, number_of_classes)
    return logits_flatten

def flatten_annotations(annotations):
    return annotations.view(-1)

def get_valid_annotations_index(flatten_annotations, mask_out_value=255):
    return torch.squeeze( torch.nonzero((flatten_annotations != mask_out_value )), 1)

### Train the model


In [5]:
###################### Train Model #############################
# Calculate total iter_num
total_iter_num = 0

for epoch in range(opt.num_epochs):#411, 411+opt.num_epochs):
    print('Epoch {}/{}'.format(epoch, opt.num_epochs - 1))
    print('-' * 10)

    # Each epoch has a training and validation phase
    running_loss = 0.0
    
    
    # Iterate over data.
    for i in range(int(dataloader.size()/opt.batchSize)):
        total_iter_num += 1
        
        # Get data
        inputs, labels =  dataloader.get_batch()
        inputs = inputs.to(device)
        labels = labels.to(device)
        
        # We need to flatten annotations and logits to apply index of valid annotations.
        anno_flatten = flatten_annotations(labels)
        index = get_valid_annotations_index(anno_flatten, mask_out_value=255)
        anno_flatten_valid = torch.index_select(anno_flatten, 0, index)
        
        ## Create Graph ##
        if graph_created == False:
            graph_created = True
            writer.add_graph(model, inputs, verbose=False)
        
        # Forward Prop
        optimizer.zero_grad()
        torch.set_grad_enabled(True)
        logits = model(inputs)
        
        # Calculate Loss
        logits_flatten = flatten_logits(logits, number_of_classes=opt.num_classes)
        logits_flatten_valid = torch.index_select(logits_flatten, 0, index)
        loss = criterion(logits_flatten_valid, anno_flatten_valid)
        
        # Backward Prop
        loss.backward()
        optimizer.step()

        # statistics
        running_loss += loss.item()
        writer.add_scalar('loss', loss.item(), total_iter_num)

        
        # Print image every N epochs
        nTestInterval = 1
        if (epoch % nTestInterval) == 0:
            img_tensor = inputs[:3].clone().cpu()
            
            output_tensor = torch.unsqueeze(torch.max(logits[:3], 1)[1].detach().cpu().float(), 1)
            output_tensor = torch.cat((output_tensor, output_tensor, output_tensor), 1)
            
            label_tensor = labels[:3].detach().cpu().float()
            label_tensor = torch.cat((label_tensor, label_tensor, label_tensor), 1)
            
            images = []
            for img, output, label in zip(img_tensor, output_tensor, label_tensor):
                images.append(img)
                images.append(output)
                images.append(label)

            grid_image = make_grid(images, 3, normalize=True, scale_each=False )
            writer.add_image('Train', grid_image, epoch)
        
        if (i % 2 == 0):
            print('Epoch{} Batch{} Loss: {:.4f}'.format(epoch, i, loss.item()))

    epoch_loss = running_loss / (dataloader.size()/opt.batchSize)
    writer.add_scalar('epoch_loss', epoch_loss, epoch)
    print('{} Loss: {:.4f}'.format(phase, epoch_loss))
    
    # Save the model checkpoint
    directory = opt.logs_path+'/checkpoints/'
    if not os.path.exists(directory):
        os.makedirs(directory)
        
    if (epoch % 5 == 0):
        filename = opt.logs_path + '/checkpoints/checkpoint-epoch_{}.pth'.format(epoch,i)
        torch.save(model.state_dict(), filename)
        

  

Epoch 0/499
----------
Epoch0 Batch0 Loss: 1645235.0000
Epoch0 Batch2 Loss: 1593438.3750
Epoch0 Batch4 Loss: 1528245.8750
train Loss: 1510709.6000
Epoch 1/499
----------
Epoch1 Batch0 Loss: 1451585.0000
Epoch1 Batch2 Loss: 1367963.8750
Epoch1 Batch4 Loss: 1065669.7500
train Loss: 1163361.9300
Epoch 2/499
----------
Epoch2 Batch0 Loss: 1144395.8750
Epoch2 Batch2 Loss: 644112.4375
Epoch2 Batch4 Loss: 949840.5000
train Loss: 879449.6300
Epoch 3/499
----------
Epoch3 Batch0 Loss: 802621.6875
Epoch3 Batch2 Loss: 829533.9375
Epoch3 Batch4 Loss: 838463.4375
train Loss: 760407.0700
Epoch 4/499
----------
Epoch4 Batch0 Loss: 606759.1875
Epoch4 Batch2 Loss: 615435.8125
Epoch4 Batch4 Loss: 941206.5625
train Loss: 644992.4000
Epoch 5/499
----------
Epoch5 Batch0 Loss: 602447.8750
Epoch5 Batch2 Loss: 677729.1250
Epoch5 Batch4 Loss: 739058.6250
train Loss: 623530.7100
Epoch 6/499
----------
Epoch6 Batch0 Loss: 574438.6875
Epoch6 Batch2 Loss: 675520.1250
Epoch6 Batch4 Loss: 662643.3125
train Loss: 60

Epoch 56/499
----------
Epoch56 Batch0 Loss: 430143.0625
Epoch56 Batch2 Loss: 519114.4688
Epoch56 Batch4 Loss: 532301.8750
train Loss: 469853.9300
Epoch 57/499
----------
Epoch57 Batch0 Loss: 469449.1250
Epoch57 Batch2 Loss: 472018.7500
Epoch57 Batch4 Loss: 480425.7500
train Loss: 486129.3550
Epoch 58/499
----------
Epoch58 Batch0 Loss: 428577.2812
Epoch58 Batch2 Loss: 509430.6875
Epoch58 Batch4 Loss: 453721.2188
train Loss: 500055.1750
Epoch 59/499
----------
Epoch59 Batch0 Loss: 431414.8438
Epoch59 Batch2 Loss: 458687.5312
Epoch59 Batch4 Loss: 409134.8438
train Loss: 455027.5200
Epoch 60/499
----------
Epoch60 Batch0 Loss: 562472.4375
Epoch60 Batch2 Loss: 400434.9062
Epoch60 Batch4 Loss: 468270.5000
train Loss: 459010.7000
Epoch 61/499
----------
Epoch61 Batch0 Loss: 535811.0625
Epoch61 Batch2 Loss: 409031.2188
Epoch61 Batch4 Loss: 433123.3750
train Loss: 441603.0600
Epoch 62/499
----------
Epoch62 Batch0 Loss: 591772.7500
Epoch62 Batch2 Loss: 436178.2500
Epoch62 Batch4 Loss: 426187.

Epoch111 Batch4 Loss: 316915.6875
train Loss: 345041.1950
Epoch 112/499
----------
Epoch112 Batch0 Loss: 472875.7812
Epoch112 Batch2 Loss: 347229.0938
Epoch112 Batch4 Loss: 321459.2188
train Loss: 340233.6850
Epoch 113/499
----------
Epoch113 Batch0 Loss: 505697.8438
Epoch113 Batch2 Loss: 344408.6250
Epoch113 Batch4 Loss: 378801.6875
train Loss: 343812.2900
Epoch 114/499
----------
Epoch114 Batch0 Loss: 426714.0625
Epoch114 Batch2 Loss: 321238.3750
Epoch114 Batch4 Loss: 372100.2188
train Loss: 346180.8050
Epoch 115/499
----------
Epoch115 Batch0 Loss: 368272.0000
Epoch115 Batch2 Loss: 346912.6875
Epoch115 Batch4 Loss: 367658.6562
train Loss: 355238.3200
Epoch 116/499
----------
Epoch116 Batch0 Loss: 343285.6875
Epoch116 Batch2 Loss: 342101.7812
Epoch116 Batch4 Loss: 357392.6250
train Loss: 359150.5850
Epoch 117/499
----------
Epoch117 Batch0 Loss: 339125.0938
Epoch117 Batch2 Loss: 315187.7188
Epoch117 Batch4 Loss: 314113.1562
train Loss: 350381.3800
Epoch 118/499
----------
Epoch118 Ba

Epoch 166/499
----------
Epoch166 Batch0 Loss: 278870.0625
Epoch166 Batch2 Loss: 267109.8125
Epoch166 Batch4 Loss: 290908.9688
train Loss: 278686.2075
Epoch 167/499
----------
Epoch167 Batch0 Loss: 251766.9062
Epoch167 Batch2 Loss: 259148.0000
Epoch167 Batch4 Loss: 231507.3281
train Loss: 274274.0175
Epoch 168/499
----------
Epoch168 Batch0 Loss: 239709.8906
Epoch168 Batch2 Loss: 318648.6875
Epoch168 Batch4 Loss: 235129.8438
train Loss: 273559.0225
Epoch 169/499
----------
Epoch169 Batch0 Loss: 245474.7656
Epoch169 Batch2 Loss: 353323.4062
Epoch169 Batch4 Loss: 229998.7969
train Loss: 271593.9200
Epoch 170/499
----------
Epoch170 Batch0 Loss: 231585.2031
Epoch170 Batch2 Loss: 357853.4688
Epoch170 Batch4 Loss: 248581.0312
train Loss: 263268.3175
Epoch 171/499
----------
Epoch171 Batch0 Loss: 235932.7656
Epoch171 Batch2 Loss: 372118.3438
Epoch171 Batch4 Loss: 253589.0312
train Loss: 250303.6725
Epoch 172/499
----------
Epoch172 Batch0 Loss: 302420.0625
Epoch172 Batch2 Loss: 319280.3750
E

Epoch220 Batch2 Loss: 298526.3750
Epoch220 Batch4 Loss: 186855.7500
train Loss: 210703.5725
Epoch 221/499
----------
Epoch221 Batch0 Loss: 194687.4219
Epoch221 Batch2 Loss: 309280.2500
Epoch221 Batch4 Loss: 215474.2812
train Loss: 203731.6350
Epoch 222/499
----------
Epoch222 Batch0 Loss: 249815.4219
Epoch222 Batch2 Loss: 275040.8125
Epoch222 Batch4 Loss: 219250.0625
train Loss: 215109.7625
Epoch 223/499
----------
Epoch223 Batch0 Loss: 253616.2188
Epoch223 Batch2 Loss: 247948.5312
Epoch223 Batch4 Loss: 232370.0156
train Loss: 222481.4200
Epoch 224/499
----------
Epoch224 Batch0 Loss: 286947.1250
Epoch224 Batch2 Loss: 227853.0781
Epoch224 Batch4 Loss: 236560.4844
train Loss: 236859.0000
Epoch 225/499
----------
Epoch225 Batch0 Loss: 247776.0938
Epoch225 Batch2 Loss: 211202.9375
Epoch225 Batch4 Loss: 230734.6406
train Loss: 237164.1500
Epoch 226/499
----------
Epoch226 Batch0 Loss: 171803.0625
Epoch226 Batch2 Loss: 209385.3594
Epoch226 Batch4 Loss: 264279.2188
train Loss: 225807.0325
Ep

train Loss: 184484.9525
Epoch 275/499
----------
Epoch275 Batch0 Loss: 186764.5156
Epoch275 Batch2 Loss: 166870.7188
Epoch275 Batch4 Loss: 185522.9062
train Loss: 187920.5075
Epoch 276/499
----------
Epoch276 Batch0 Loss: 163407.1875
Epoch276 Batch2 Loss: 161622.1719
Epoch276 Batch4 Loss: 236668.9688
train Loss: 193599.8800
Epoch 277/499
----------
Epoch277 Batch0 Loss: 134582.6875
Epoch277 Batch2 Loss: 168839.1562
Epoch277 Batch4 Loss: 258975.0781
train Loss: 192382.1300
Epoch 278/499
----------
Epoch278 Batch0 Loss: 139123.9062
Epoch278 Batch2 Loss: 162627.0625
Epoch278 Batch4 Loss: 272600.0938
train Loss: 185106.9325
Epoch 279/499
----------
Epoch279 Batch0 Loss: 160527.3438
Epoch279 Batch2 Loss: 170276.9531
Epoch279 Batch4 Loss: 271061.3438
train Loss: 184817.2750
Epoch 280/499
----------
Epoch280 Batch0 Loss: 177427.5156
Epoch280 Batch2 Loss: 217738.7812
Epoch280 Batch4 Loss: 222018.2031
train Loss: 180054.2600
Epoch 281/499
----------
Epoch281 Batch0 Loss: 171885.3125
Epoch281 Ba

Epoch329 Batch0 Loss: 112798.6719
Epoch329 Batch2 Loss: 118634.5000
Epoch329 Batch4 Loss: 204486.7344
train Loss: 133181.2275
Epoch 330/499
----------
Epoch330 Batch0 Loss: 125031.6094
Epoch330 Batch2 Loss: 165740.3281
Epoch330 Batch4 Loss: 163157.0000
train Loss: 134003.5700
Epoch 331/499
----------
Epoch331 Batch0 Loss: 134438.6406
Epoch331 Batch2 Loss: 165503.0625
Epoch331 Batch4 Loss: 150765.2656
train Loss: 134166.3975
Epoch 332/499
----------
Epoch332 Batch0 Loss: 150540.1719
Epoch332 Batch2 Loss: 161675.7656
Epoch332 Batch4 Loss: 133636.2500
train Loss: 136256.5088
Epoch 333/499
----------
Epoch333 Batch0 Loss: 140115.7031
Epoch333 Batch2 Loss: 139077.2812
Epoch333 Batch4 Loss: 127019.1484
train Loss: 136186.6663
Epoch 334/499
----------
Epoch334 Batch0 Loss: 130592.9219
Epoch334 Batch2 Loss: 103543.5156
Epoch334 Batch4 Loss: 115541.6172
train Loss: 130287.9075
Epoch 335/499
----------
Epoch335 Batch0 Loss: 164760.7812
Epoch335 Batch2 Loss: 87248.6719
Epoch335 Batch4 Loss: 11298

Epoch383 Batch4 Loss: 102675.5703
train Loss: 109010.0312
Epoch 384/499
----------
Epoch384 Batch0 Loss: 104081.2344
Epoch384 Batch2 Loss: 81066.2891
Epoch384 Batch4 Loss: 93458.6953
train Loss: 103598.2238
Epoch 385/499
----------
Epoch385 Batch0 Loss: 133411.0156
Epoch385 Batch2 Loss: 70254.9062
Epoch385 Batch4 Loss: 86813.4688
train Loss: 103102.9288
Epoch 386/499
----------
Epoch386 Batch0 Loss: 145850.2500
Epoch386 Batch2 Loss: 73991.5625
Epoch386 Batch4 Loss: 89372.6484
train Loss: 106961.2787
Epoch 387/499
----------
Epoch387 Batch0 Loss: 159793.0781
Epoch387 Batch2 Loss: 87669.3750
Epoch387 Batch4 Loss: 97684.7578
train Loss: 108051.8525
Epoch 388/499
----------
Epoch388 Batch0 Loss: 165865.1406
Epoch388 Batch2 Loss: 105047.1484
Epoch388 Batch4 Loss: 137429.5000
train Loss: 111256.8025
Epoch 389/499
----------
Epoch389 Batch0 Loss: 137996.6719
Epoch389 Batch2 Loss: 108556.4219
Epoch389 Batch4 Loss: 148749.9531
train Loss: 111211.9812
Epoch 390/499
----------
Epoch390 Batch0 Los

Epoch438 Batch4 Loss: 121869.6406
train Loss: 102664.1425
Epoch 439/499
----------
Epoch439 Batch0 Loss: 127247.9922
Epoch439 Batch2 Loss: 91517.7500
Epoch439 Batch4 Loss: 128825.2734
train Loss: 99395.0806
Epoch 440/499
----------
Epoch440 Batch0 Loss: 105364.9375
Epoch440 Batch2 Loss: 115390.5156
Epoch440 Batch4 Loss: 110063.1406
train Loss: 97419.2050
Epoch 441/499
----------
Epoch441 Batch0 Loss: 97971.4844
Epoch441 Batch2 Loss: 99122.1719
Epoch441 Batch4 Loss: 102128.7656
train Loss: 97881.2075
Epoch 442/499
----------
Epoch442 Batch0 Loss: 91969.4141
Epoch442 Batch2 Loss: 90835.8906
Epoch442 Batch4 Loss: 77022.7031
train Loss: 98163.9375
Epoch 443/499
----------
Epoch443 Batch0 Loss: 82144.7266
Epoch443 Batch2 Loss: 121133.5859
Epoch443 Batch4 Loss: 60294.9219
train Loss: 95410.4125
Epoch 444/499
----------
Epoch444 Batch0 Loss: 81729.6016
Epoch444 Batch2 Loss: 126148.7422
Epoch444 Batch4 Loss: 68247.2188
train Loss: 96184.9250
Epoch 445/499
----------
Epoch445 Batch0 Loss: 78545

Epoch494 Batch0 Loss: 66902.4609
Epoch494 Batch2 Loss: 111201.6953
Epoch494 Batch4 Loss: 56353.3359
train Loss: 80857.2525
Epoch 495/499
----------
Epoch495 Batch0 Loss: 68533.8359
Epoch495 Batch2 Loss: 120699.2266
Epoch495 Batch4 Loss: 67942.5312
train Loss: 81700.6913
Epoch 496/499
----------
Epoch496 Batch0 Loss: 73158.7109
Epoch496 Batch2 Loss: 123939.8203
Epoch496 Batch4 Loss: 77686.6172
train Loss: 79059.1562
Epoch 497/499
----------
Epoch497 Batch0 Loss: 103675.2031
Epoch497 Batch2 Loss: 100358.8125
Epoch497 Batch4 Loss: 79589.6016
train Loss: 82805.1650
Epoch 498/499
----------
Epoch498 Batch0 Loss: 104467.7031
Epoch498 Batch2 Loss: 88875.0859
Epoch498 Batch4 Loss: 93294.5938
train Loss: 83748.0650
Epoch 499/499
----------
Epoch499 Batch0 Loss: 94099.9219
Epoch499 Batch2 Loss: 85716.4531
Epoch499 Batch4 Loss: 82257.5547
train Loss: 83432.0663
