In [1]:
import sys
sys.path.append('./models/')
import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
import time
import os
from data_loader import Dataset,Options
import models.unet_normals as unet
from tensorboardX import SummaryWriter
# import OpenEXR, Imath

### Setup Options
Set the various parameters:
- dataroot: The folder where the training data is stored
- file_list: List of filenames of images for training
- batchSize: Batch size for model
- shuffle: If true, will shuffle the dataset
- phase: If 'train', then it's in training mode.
- num_epochs: Number of epochs to train the model for
- imsize: Dimensions of the image (square)
- num_classes: Num of classes in the output
- gpu: Which GPU device to use
- logs_path: The path where the log files (tensorboard) will be saved.

In [2]:
class OPT():
    def __init__(self):
        self.dataroot = './data/'
        self.file_list = './data/datalist'
        self.batchSize = 32
        self.shuffle = True
        self.phase = 'train'
        self.num_epochs = 500
        self.imsize = 224
        self.num_classes = int(3)
        self.gpu = '0'
        self.logs_path = 'logs/exp11-3'
        self.use_pretrained = True

opt = OPT()

### Setup logging and dataloaders

In [3]:
###################### Options #############################
phase = opt.phase
device = torch.device("cuda:"+ opt.gpu if torch.cuda.is_available() else "cpu")

###################### TensorBoardX #############################
if os.path.exists(opt.logs_path):
    raise Exception('The folder \"{}\" already exists! Define a new log path or delete old contents.'.format(opt.logs_path))
    
writer = SummaryWriter(opt.logs_path, comment='create-graph')
graph_created = False

###################### DataLoader #############################
dataloader = Dataset(opt)


shuffling the dataset


### Create the model
We use a UNet model. The last few layers of this model are modified to return a 3 channel image, containing the x,y,z values of surface normal vectors.

In [4]:
###################### ModelBuilder #############################
model = unet.Unet(num_classes=opt.num_classes)

# Load weights from checkpoint
if (opt.use_pretrained == True):
    checkpoint_path = 'logs/exp11-2/checkpoints/checkpoint-epoch_1300.pth'
    model.load_state_dict(torch.load(checkpoint_path))

model = model.to(device)
model.train()

###################### Setup Optimazation #############################
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
exp_lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=7, gamma=0.1)

###################### Loss fuction #############################
'''
@input: The 2 vectors whose cosine loss is to be calculated
The dimensions of the matrices are expected to be (batchSize, 3, imsize, imsize). 

@return: 
elementwise_mean: will return the sum of all losses divided by num of elements
none: The loss will be calculated to be of size (batchSize, imsize, imsize) containing cosine loss of each pixel
'''
def loss_fn(input_vec, target_vec, reduction='elementwise_mean'):
    cos = nn.CosineSimilarity(dim=1, eps=1e-6)
    loss_val = 1.0 - cos(input_vec, target_vec)
    if (reduction=='elementwise_mean'):
        return torch.mean(loss_val)
    elif (reduction=='none'):
        return loss_val
    else:
        raise Exception('Warning! The reduction is invalid. Please use \'elementwise_mean\' or \'none\''.format())


### Train the model


In [None]:
###################### Train Model #############################
# Calculate total iter_num
total_iter_num = 98000

for epoch in range(1305, 1305+1500):
    print('Epoch {}/{}'.format(epoch, opt.num_epochs+1500 - 1))
    print('-' * 10)

    # Each epoch has a training and validation phase
    running_loss = 0.0
    
    


    # Iterate over data.
    for i in range(int(dataloader.size()/opt.batchSize)):
        total_iter_num += 1
        
        # Get data
        inputs, labels =  dataloader.get_batch()
        inputs = inputs.to(device)
        labels = labels.to(device)
        
        #ToDo: get labels into correct format
        
        ## Create Graph ##
        if graph_created == False:
            graph_created = True
            writer.add_graph(model, inputs, verbose=False)
        
        # Forward + Backward Prop
        optimizer.zero_grad()
        torch.set_grad_enabled(True)
        normal_vectors = model(inputs)
        normal_vectors_norm = nn.functional.normalize(normal_vectors, p=2, dim=1)
        
        loss = loss_fn(normal_vectors_norm, labels, reduction='elementwise_mean')
        ### Scale Loss by a factor of 100 ###
        loss = loss*100
        #####################################
        loss.backward()
        optimizer.step()

        # statistics
        running_loss += loss.item()
        writer.add_scalar('loss', loss.item(), total_iter_num)
        
        if (i % 10 == 0):
            print('Epoch{} Batch{} Loss: {:.4f}'.format(epoch, i, loss.item()))

    #exp_lr_scheduler.step() # This is for the LR Scheduler
    epoch_loss = running_loss / (dataloader.size()/opt.batchSize)
    writer.add_scalar('epoch_loss', epoch_loss, epoch)
    print('{} Loss: {:.4f}'.format(phase, epoch_loss))
    
    # Save the model checkpoint
    directory = opt.logs_path+'/checkpoints/'
    if not os.path.exists(directory):
        os.makedirs(directory)
        
    if (epoch % 5 == 0):
        filename = opt.logs_path + '/checkpoints/checkpoint-epoch_{}.pth'.format(epoch,i)
        torch.save(model.state_dict(), filename)
        

# Save final Checkpoint
filename = opt.logs_path + '/checkpoints/checkpoint.pth'
torch.save(model.state_dict(), filename)


Epoch 1305/1999
----------
Epoch1305 Batch0 Loss: 0.1084
Epoch1305 Batch10 Loss: 27.5719
Epoch1305 Batch20 Loss: 17.2170
Epoch1305 Batch30 Loss: 20.7190
Epoch1305 Batch40 Loss: 17.4696
Epoch1305 Batch50 Loss: 23.6811
Epoch1305 Batch60 Loss: 24.2594
Epoch1305 Batch70 Loss: 19.0002
shuffling the dataset
train Loss: 22.7854
Epoch 1306/1999
----------
Epoch1306 Batch0 Loss: 19.3001
Epoch1306 Batch10 Loss: 23.9740
Epoch1306 Batch20 Loss: 22.1439
Epoch1306 Batch30 Loss: 25.1439
Epoch1306 Batch40 Loss: 22.3646
Epoch1306 Batch50 Loss: 23.6209
Epoch1306 Batch60 Loss: 20.5597
Epoch1306 Batch70 Loss: 22.4340
shuffling the dataset
train Loss: 22.3486
Epoch 1307/1999
----------
Epoch1307 Batch0 Loss: 22.2476
Epoch1307 Batch10 Loss: 22.1259
Epoch1307 Batch20 Loss: 20.6062
Epoch1307 Batch30 Loss: 19.8032
Epoch1307 Batch40 Loss: 21.0468
Epoch1307 Batch50 Loss: 19.9323
Epoch1307 Batch60 Loss: 21.5210
Epoch1307 Batch70 Loss: 22.5190
shuffling the dataset
train Loss: 22.1677
Epoch 1308/1999
----------
Ep

Epoch1330 Batch30 Loss: 10.3037
Epoch1330 Batch40 Loss: 9.1361
Epoch1330 Batch50 Loss: 14.2845
Epoch1330 Batch60 Loss: 14.4471
Epoch1330 Batch70 Loss: 12.3994
shuffling the dataset
train Loss: 11.9802
Epoch 1331/1999
----------
Epoch1331 Batch0 Loss: 11.7412
Epoch1331 Batch10 Loss: 12.6361
Epoch1331 Batch20 Loss: 9.5441
Epoch1331 Batch30 Loss: 9.4450
Epoch1331 Batch40 Loss: 10.5854
Epoch1331 Batch50 Loss: 14.6807
Epoch1331 Batch60 Loss: 11.4438
Epoch1331 Batch70 Loss: 10.3526
shuffling the dataset
train Loss: 11.5274
Epoch 1332/1999
----------
Epoch1332 Batch0 Loss: 10.5276
Epoch1332 Batch10 Loss: 11.6453
Epoch1332 Batch20 Loss: 9.9441
Epoch1332 Batch30 Loss: 10.0518
Epoch1332 Batch40 Loss: 10.9913
Epoch1332 Batch50 Loss: 11.8808
Epoch1332 Batch60 Loss: 11.2748
Epoch1332 Batch70 Loss: 10.6298
shuffling the dataset
train Loss: 10.9759
Epoch 1333/1999
----------
Epoch1333 Batch0 Loss: 10.6528
Epoch1333 Batch10 Loss: 10.4520
Epoch1333 Batch20 Loss: 11.2228
Epoch1333 Batch30 Loss: 10.8621


Epoch1356 Batch20 Loss: 3.6835
Epoch1356 Batch30 Loss: 2.7145
Epoch1356 Batch40 Loss: 3.6454
Epoch1356 Batch50 Loss: 3.6761
Epoch1356 Batch60 Loss: 3.2053
Epoch1356 Batch70 Loss: 2.9093
shuffling the dataset
train Loss: 3.2530
Epoch 1357/1999
----------
Epoch1357 Batch0 Loss: 3.1587
Epoch1357 Batch10 Loss: 2.4689
Epoch1357 Batch20 Loss: 2.8471
Epoch1357 Batch30 Loss: 3.1814
Epoch1357 Batch40 Loss: 3.4468
Epoch1357 Batch50 Loss: 2.5087
Epoch1357 Batch60 Loss: 3.1368
Epoch1357 Batch70 Loss: 3.7118
shuffling the dataset
train Loss: 3.1777
Epoch 1358/1999
----------
Epoch1358 Batch0 Loss: 2.4847
Epoch1358 Batch10 Loss: 3.2685
Epoch1358 Batch20 Loss: 2.7303
Epoch1358 Batch30 Loss: 3.1612
Epoch1358 Batch40 Loss: 2.4471
Epoch1358 Batch50 Loss: 2.8813
Epoch1358 Batch60 Loss: 3.4253
Epoch1358 Batch70 Loss: 2.9333
shuffling the dataset
train Loss: 3.0883
Epoch 1359/1999
----------
Epoch1359 Batch0 Loss: 3.0121
Epoch1359 Batch10 Loss: 3.4219
Epoch1359 Batch20 Loss: 3.1088
Epoch1359 Batch30 Loss: 

Epoch1382 Batch20 Loss: 1.6044
Epoch1382 Batch30 Loss: 1.6405
Epoch1382 Batch40 Loss: 1.7521
Epoch1382 Batch50 Loss: 1.7262
Epoch1382 Batch60 Loss: 1.4808
Epoch1382 Batch70 Loss: 1.5363
shuffling the dataset
train Loss: 1.9421
Epoch 1383/1999
----------
Epoch1383 Batch0 Loss: 2.1242
Epoch1383 Batch10 Loss: 1.6787
Epoch1383 Batch20 Loss: 1.8124
Epoch1383 Batch30 Loss: 1.9514
Epoch1383 Batch40 Loss: 1.8510
Epoch1383 Batch50 Loss: 2.8162
Epoch1383 Batch60 Loss: 1.6658
Epoch1383 Batch70 Loss: 1.7797
shuffling the dataset
train Loss: 1.9228
Epoch 1384/1999
----------
Epoch1384 Batch0 Loss: 2.1446
Epoch1384 Batch10 Loss: 1.2827
Epoch1384 Batch20 Loss: 1.9388
Epoch1384 Batch30 Loss: 2.0104
Epoch1384 Batch40 Loss: 2.2169
Epoch1384 Batch50 Loss: 2.0473
Epoch1384 Batch60 Loss: 1.6936
Epoch1384 Batch70 Loss: 1.5134
shuffling the dataset
train Loss: 1.9024
Epoch 1385/1999
----------
Epoch1385 Batch0 Loss: 1.1956
Epoch1385 Batch10 Loss: 2.0939
Epoch1385 Batch20 Loss: 1.7422
Epoch1385 Batch30 Loss: 

Epoch1408 Batch30 Loss: 1.9673
Epoch1408 Batch40 Loss: 1.5470
Epoch1408 Batch50 Loss: 2.1826
Epoch1408 Batch60 Loss: 1.4797
Epoch1408 Batch70 Loss: 1.4733
shuffling the dataset
train Loss: 1.6739
Epoch 1409/1999
----------
Epoch1409 Batch0 Loss: 1.4037
Epoch1409 Batch10 Loss: 1.8244
Epoch1409 Batch20 Loss: 1.5540
Epoch1409 Batch30 Loss: 1.5608
Epoch1409 Batch40 Loss: 1.4899
Epoch1409 Batch50 Loss: 1.4051
Epoch1409 Batch60 Loss: 1.3704
Epoch1409 Batch70 Loss: 1.6496
shuffling the dataset
train Loss: 1.6327
Epoch 1410/1999
----------
Epoch1410 Batch0 Loss: 1.5083
Epoch1410 Batch10 Loss: 1.3867
Epoch1410 Batch20 Loss: 1.1525
Epoch1410 Batch30 Loss: 1.5076
Epoch1410 Batch40 Loss: 1.2809
Epoch1410 Batch50 Loss: 1.6872
Epoch1410 Batch60 Loss: 1.4802
Epoch1410 Batch70 Loss: 1.9504
shuffling the dataset
train Loss: 1.5576
Epoch 1411/1999
----------
Epoch1411 Batch0 Loss: 1.3208
Epoch1411 Batch10 Loss: 1.5008
Epoch1411 Batch20 Loss: 1.5114
Epoch1411 Batch30 Loss: 1.2505
Epoch1411 Batch40 Loss: 

Epoch1434 Batch30 Loss: 7.0893
Epoch1434 Batch40 Loss: 6.0969
Epoch1434 Batch50 Loss: 5.8757
Epoch1434 Batch60 Loss: 5.7600
Epoch1434 Batch70 Loss: 6.6441
shuffling the dataset
train Loss: 8.2445
Epoch 1435/1999
----------
Epoch1435 Batch0 Loss: 5.8791
Epoch1435 Batch10 Loss: 4.4160
Epoch1435 Batch20 Loss: 6.1179
Epoch1435 Batch30 Loss: 3.9452
Epoch1435 Batch40 Loss: 3.0999
Epoch1435 Batch50 Loss: 3.2854
Epoch1435 Batch60 Loss: 3.1575
Epoch1435 Batch70 Loss: 4.0002
shuffling the dataset
train Loss: 3.9182
Epoch 1436/1999
----------
Epoch1436 Batch0 Loss: 2.3439
Epoch1436 Batch10 Loss: 2.8768
Epoch1436 Batch20 Loss: 2.2728
Epoch1436 Batch30 Loss: 1.8285
Epoch1436 Batch40 Loss: 2.0075
Epoch1436 Batch50 Loss: 2.1179
Epoch1436 Batch60 Loss: 3.8784
Epoch1436 Batch70 Loss: 1.8743
shuffling the dataset
train Loss: 2.1525
Epoch 1437/1999
----------
Epoch1437 Batch0 Loss: 1.7420
Epoch1437 Batch10 Loss: 1.6396
Epoch1437 Batch20 Loss: 3.8707
Epoch1437 Batch30 Loss: 1.6958
Epoch1437 Batch40 Loss: 

Epoch1460 Batch40 Loss: 0.8151
Epoch1460 Batch50 Loss: 1.0681
Epoch1460 Batch60 Loss: 0.8059
Epoch1460 Batch70 Loss: 0.9243
shuffling the dataset
train Loss: 0.9550
Epoch 1461/1999
----------
Epoch1461 Batch0 Loss: 0.9755
Epoch1461 Batch10 Loss: 0.8617
Epoch1461 Batch20 Loss: 0.8899
Epoch1461 Batch30 Loss: 0.6088
Epoch1461 Batch40 Loss: 0.8261
Epoch1461 Batch50 Loss: 1.0946
Epoch1461 Batch60 Loss: 0.7882
Epoch1461 Batch70 Loss: 0.9338
shuffling the dataset
train Loss: 0.9658
Epoch 1462/1999
----------
Epoch1462 Batch0 Loss: 0.7302
Epoch1462 Batch10 Loss: 0.8605
Epoch1462 Batch20 Loss: 1.0012
Epoch1462 Batch30 Loss: 3.7837
Epoch1462 Batch40 Loss: 0.8529
Epoch1462 Batch50 Loss: 0.6332
Epoch1462 Batch60 Loss: 0.8124
Epoch1462 Batch70 Loss: 0.7097
shuffling the dataset
train Loss: 0.9740
Epoch 1463/1999
----------
Epoch1463 Batch0 Loss: 0.8716
Epoch1463 Batch10 Loss: 0.8993
Epoch1463 Batch20 Loss: 0.8774
Epoch1463 Batch30 Loss: 0.8303
Epoch1463 Batch40 Loss: 0.9743
Epoch1463 Batch50 Loss: 

Epoch1486 Batch50 Loss: 0.8174
Epoch1486 Batch60 Loss: 0.5966
Epoch1486 Batch70 Loss: 0.5911
shuffling the dataset
train Loss: 0.8265
Epoch 1487/1999
----------
Epoch1487 Batch0 Loss: 0.6467
Epoch1487 Batch10 Loss: 0.8067
Epoch1487 Batch20 Loss: 0.8452
Epoch1487 Batch30 Loss: 0.9912
Epoch1487 Batch40 Loss: 0.5429
Epoch1487 Batch50 Loss: 0.7645
Epoch1487 Batch60 Loss: 0.5063
Epoch1487 Batch70 Loss: 0.7509
shuffling the dataset
train Loss: 0.8182
Epoch 1488/1999
----------
Epoch1488 Batch0 Loss: 0.6959
Epoch1488 Batch10 Loss: 0.6996
Epoch1488 Batch20 Loss: 3.5101
Epoch1488 Batch30 Loss: 0.7399
Epoch1488 Batch40 Loss: 0.6984
Epoch1488 Batch50 Loss: 0.7435
Epoch1488 Batch60 Loss: 0.7832
Epoch1488 Batch70 Loss: 0.9552
shuffling the dataset
train Loss: 0.8163
Epoch 1489/1999
----------
Epoch1489 Batch0 Loss: 0.5115
Epoch1489 Batch10 Loss: 0.5910
Epoch1489 Batch20 Loss: 0.6667
Epoch1489 Batch30 Loss: 0.7017
Epoch1489 Batch40 Loss: 0.9031
Epoch1489 Batch50 Loss: 0.8952
Epoch1489 Batch60 Loss: 