In [6]:
import sys
sys.path.append('./models/')
import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
import time
import os
from data_loader import Dataset,Options
import models.unet_normals as unet
from tensorboardX import SummaryWriter
# import OpenEXR, Imath

### Setup Options
Set the various parameters:
- dataroot: The folder where the training data is stored
- file_list: Path to the list of image filenames used for training
- batchSize: Batch size for the model
- shuffle: If True, the dataset is shuffled
- phase: If 'train', the notebook runs in training mode
- num_epochs: Number of epochs to train the model for
- imsize: Dimensions of the (square) image
- num_classes: Number of channels in the model output (3: the x, y, z components of the surface normal)
- gpu: Which GPU device to use (index as a string, e.g. '0')
- logs_path: The path where the TensorBoard log files and checkpoints will be saved
- use_pretrained: If True, load weights from a saved checkpoint before training

In [7]:
class OPT:
    """Hyper-parameters and paths for the surface-normal training run.

    Every attribute assigned in ``__init__`` is a default. Keyword arguments
    override individual options, e.g. ``OPT(batchSize=8, logs_path='logs/exp10')``.
    An unknown option name raises ``TypeError`` so a typo is not silently ignored.
    """

    def __init__(self, **overrides):
        self.dataroot = './data/'            # folder where the training data is stored
        self.file_list = './data/datalist'   # list of image filenames used for training
        self.batchSize = 32
        self.shuffle = True                  # shuffle the dataset
        self.phase = 'train'
        self.num_epochs = 500
        self.imsize = 224                    # images are square: imsize x imsize
        self.num_classes = 3                 # output channels: x, y, z of the surface normal
        self.gpu = '0'                       # CUDA device index, kept as a string ("cuda:" + gpu)
        self.logs_path = 'logs/exp9'         # TensorBoard logs and checkpoints are written here
        self.use_pretrained = False          # if True, load a checkpoint before training

        for key, value in overrides.items():
            if not hasattr(self, key):
                raise TypeError('OPT got an unexpected option {!r}'.format(key))
            setattr(self, key, value)

opt = OPT()

### Setup logging and dataloaders

In [8]:
###################### Options #############################
phase = opt.phase
# Use the configured GPU index when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda:" + opt.gpu if torch.cuda.is_available() else "cpu")

###################### TensorBoardX #############################
# Refuse to reuse an existing log folder so a new run never mixes its event files
# and checkpoints with those of an older experiment.
if os.path.exists(opt.logs_path):
    raise FileExistsError('The folder "{}" already exists! Define a new log path or delete old contents.'.format(opt.logs_path))

# NOTE(review): tensorboardX documents `comment` as a suffix for the *default*
# logdir only, so it should have no effect here because a logdir is passed explicitly.
writer = SummaryWriter(opt.logs_path, comment='create-graph')
graph_created = False  # flipped to True once the model graph has been written

###################### DataLoader #############################
dataloader = Dataset(opt)


shuffling the dataset


### Create the model
We use a UNet model. Its last few layers are modified to output a 3-channel image containing the x, y, z components of the surface normal vectors.

In [9]:
###################### ModelBuilder #############################
model = unet.Unet(num_classes=opt.num_classes)

# Load weights from checkpoint
if opt.use_pretrained:
    checkpoint_path = 'logs/exp7/checkpoints/checkpoint.pth'
    # The model is still on the CPU at this point, so load the tensors onto the CPU too.
    # This also lets a checkpoint saved from a GPU run be loaded on a CPU-only machine.
    model.load_state_dict(torch.load(checkpoint_path, map_location='cpu'))

model = model.to(device)
model.train()

###################### Setup Optimization #############################
optimizer = torch.optim.Adam(model.parameters(), lr=0.0001)
# NOTE(review): the training loop never calls exp_lr_scheduler.step(), so this scheduler
# currently has no effect and the learning rate stays at 1e-4. If it is enabled,
# StepLR(step_size=7, gamma=0.1) divides the LR by 10 every 7 step() calls (every 7 epochs
# if stepped once per epoch), which is very aggressive for a 500-epoch run. Revisit first.
exp_lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=7, gamma=0.1)

###################### Loss function #############################
def loss_fn(input_vec, target_vec, reduction='elementwise_mean'):
    '''Cosine loss, i.e. 1 - cosine similarity, between two tensors of vectors.

    Similarity is computed along dim 1. For inputs of shape
    (batchSize, 3, imsize, imsize), each pixel's 3-vector is compared with the
    corresponding target 3-vector.

    @input:
    input_vec, target_vec: tensors whose dim 1 holds the vector components
        (expected shape (batchSize, 3, imsize, imsize)).
    reduction:
        'elementwise_mean' or 'mean' - mean of all per-element losses (scalar tensor)
        'sum'                        - sum of all per-element losses (scalar tensor)
        'none'                       - unreduced loss, e.g. of shape (batchSize, imsize, imsize)

    @return: the (possibly reduced) loss tensor.

    @raises: ValueError if `reduction` is not one of the values above.
    '''
    # Functional form avoids building a new nn.CosineSimilarity module on every call.
    loss_val = 1.0 - nn.functional.cosine_similarity(input_vec, target_vec, dim=1, eps=1e-6)
    if reduction in ('elementwise_mean', 'mean'):
        return torch.mean(loss_val)
    if reduction == 'sum':
        return torch.sum(loss_val)
    if reduction == 'none':
        return loss_val
    raise ValueError("Warning! The reduction is invalid. Please use 'elementwise_mean', 'mean', 'sum' or 'none'")


### Train the model


In [None]:
###################### Train Model #############################
# Only whole batches are drawn each epoch; a trailing partial batch is skipped.
num_batches = dataloader.size() // opt.batchSize
if num_batches == 0:
    raise ValueError('Dataset has {} samples, fewer than batchSize={}; nothing to train on.'.format(dataloader.size(), opt.batchSize))

checkpoint_dir = os.path.join(opt.logs_path, 'checkpoints')
os.makedirs(checkpoint_dir, exist_ok=True)

# Global step counter for TensorBoard, accumulated across all epochs.
total_iter_num = 0

for epoch in range(opt.num_epochs):
    print('Epoch {}/{}'.format(epoch, opt.num_epochs - 1))
    print('-' * 10)

    running_loss = 0.0

    # Iterate over data.
    for i in range(num_batches):
        total_iter_num += 1

        # Get data
        inputs, labels = dataloader.get_batch()
        inputs = inputs.to(device)
        labels = labels.to(device)

        #ToDo: get labels into correct format

        # Write the model graph to TensorBoard once, using the first batch as example input.
        if not graph_created:
            graph_created = True
            writer.add_graph(model, inputs, verbose=False)

        # Forward + Backward Prop
        optimizer.zero_grad()
        torch.set_grad_enabled(True)
        normal_vectors = model(inputs)
        # L2-normalise the predicted vectors along the channel dim before the cosine loss.
        normal_vectors_norm = nn.functional.normalize(normal_vectors, p=2, dim=1)

        loss = loss_fn(normal_vectors_norm, labels, reduction='elementwise_mean')
        loss.backward()
        optimizer.step()

        # statistics (read the scalar once: .item() synchronises with the device)
        batch_loss = loss.item()
        running_loss += batch_loss
        writer.add_scalar('loss', batch_loss, total_iter_num)

        if i % 10 == 0:
            print('Epoch{} Batch{} Loss: {:.4f}'.format(epoch, i, batch_loss))

    # Average over the batches that actually ran, not the fractional size/batchSize.
    epoch_loss = running_loss / num_batches
    writer.add_scalar('epoch_loss', epoch_loss, epoch)
    print('{} Loss: {:.4f}'.format(phase, epoch_loss))
    # NOTE(review): exp_lr_scheduler is never stepped here, so the learning rate is constant.

    # Save a periodic checkpoint every 5 epochs (including epoch 0).
    if epoch % 5 == 0:
        filename = os.path.join(checkpoint_dir, 'checkpoint-epoch_{}.pth'.format(epoch))
        torch.save(model.state_dict(), filename)

# Save final Checkpoint
filename = os.path.join(checkpoint_dir, 'checkpoint.pth')
torch.save(model.state_dict(), filename)


Epoch 0/499
----------
Epoch0 Batch0 Loss: 1.4844
Epoch0 Batch10 Loss: 0.6621
Epoch0 Batch20 Loss: 0.2250
Epoch0 Batch30 Loss: 0.2829
Epoch0 Batch40 Loss: 0.2167
Epoch0 Batch50 Loss: 0.2113
Epoch0 Batch60 Loss: 0.2824
Epoch0 Batch70 Loss: 0.2510
train Loss: 0.3921
Epoch 1/499
----------
shuffling the dataset
Epoch1 Batch0 Loss: 0.2787
Epoch1 Batch10 Loss: 0.2527
Epoch1 Batch20 Loss: 0.2630
Epoch1 Batch30 Loss: 0.3045
Epoch1 Batch40 Loss: 0.2829
Epoch1 Batch50 Loss: 0.2238
Epoch1 Batch60 Loss: 0.2352
Epoch1 Batch70 Loss: 0.2280
train Loss: 0.2519
Epoch 2/499
----------
Epoch2 Batch0 Loss: 0.2026
shuffling the dataset
Epoch2 Batch10 Loss: 0.2089
Epoch2 Batch20 Loss: 0.2300
Epoch2 Batch30 Loss: 0.2468
Epoch2 Batch40 Loss: 0.2358
Epoch2 Batch50 Loss: 0.2347
Epoch2 Batch60 Loss: 0.2541
Epoch2 Batch70 Loss: 0.2400
train Loss: 0.2470
Epoch 3/499
----------
Epoch3 Batch0 Loss: 0.2367
shuffling the dataset
Epoch3 Batch10 Loss: 0.1985
Epoch3 Batch20 Loss: 0.2217
Epoch3 Batch30 Loss: 0.2370
Epoch

Epoch28 Batch0 Loss: 0.0882
Epoch28 Batch10 Loss: 0.1710
Epoch28 Batch20 Loss: 0.1334
shuffling the dataset
Epoch28 Batch30 Loss: 0.0907
Epoch28 Batch40 Loss: 0.1218
Epoch28 Batch50 Loss: 0.1117
Epoch28 Batch60 Loss: 0.0944
Epoch28 Batch70 Loss: 0.0836
train Loss: 0.0907
Epoch 29/499
----------
Epoch29 Batch0 Loss: 0.0897
Epoch29 Batch10 Loss: 0.0821
Epoch29 Batch20 Loss: 0.0550
shuffling the dataset
Epoch29 Batch30 Loss: 0.0886
Epoch29 Batch40 Loss: 0.0742
Epoch29 Batch50 Loss: 0.0716
Epoch29 Batch60 Loss: 0.0908
Epoch29 Batch70 Loss: 0.1421
train Loss: 0.0917
Epoch 30/499
----------
Epoch30 Batch0 Loss: 0.1259
Epoch30 Batch10 Loss: 0.0940
Epoch30 Batch20 Loss: 0.0623
shuffling the dataset
Epoch30 Batch30 Loss: 0.0594
Epoch30 Batch40 Loss: 0.1186
Epoch30 Batch50 Loss: 0.1049
Epoch30 Batch60 Loss: 0.0959
Epoch30 Batch70 Loss: 0.1138
train Loss: 0.0895
Epoch 31/499
----------
Epoch31 Batch0 Loss: 0.0369
Epoch31 Batch10 Loss: 0.0882
Epoch31 Batch20 Loss: 0.0896
shuffling the dataset
Epoc

Epoch55 Batch70 Loss: 0.0658
train Loss: 0.0719
Epoch 56/499
----------
Epoch56 Batch0 Loss: 0.0940
Epoch56 Batch10 Loss: 0.0385
Epoch56 Batch20 Loss: 0.0747
Epoch56 Batch30 Loss: 0.0678
Epoch56 Batch40 Loss: 0.0855
shuffling the dataset
Epoch56 Batch50 Loss: 0.0618
Epoch56 Batch60 Loss: 0.0526
Epoch56 Batch70 Loss: 0.0561
train Loss: 0.0761
Epoch 57/499
----------
Epoch57 Batch0 Loss: 0.0636
Epoch57 Batch10 Loss: 0.0675
Epoch57 Batch20 Loss: 0.1135
Epoch57 Batch30 Loss: 0.0522
Epoch57 Batch40 Loss: 0.0542
shuffling the dataset
Epoch57 Batch50 Loss: 0.0809
Epoch57 Batch60 Loss: 0.0969
Epoch57 Batch70 Loss: 0.0845
train Loss: 0.0687
Epoch 58/499
----------
Epoch58 Batch0 Loss: 0.0707
Epoch58 Batch10 Loss: 0.0842
Epoch58 Batch20 Loss: 0.0567
Epoch58 Batch30 Loss: 0.0755
Epoch58 Batch40 Loss: 0.0774
shuffling the dataset
Epoch58 Batch50 Loss: 0.0545
Epoch58 Batch60 Loss: 0.0557
Epoch58 Batch70 Loss: 0.0636
train Loss: 0.0681
Epoch 59/499
----------
Epoch59 Batch0 Loss: 0.0636
Epoch59 Batc

Epoch85 Batch40 Loss: 0.0352
Epoch85 Batch50 Loss: 0.0174
Epoch85 Batch60 Loss: 0.0280
Epoch85 Batch70 Loss: 0.0240
train Loss: 0.0349
Epoch 86/499
----------
Epoch86 Batch0 Loss: 0.0433
shuffling the dataset
Epoch86 Batch10 Loss: 0.0250
Epoch86 Batch20 Loss: 0.0340
Epoch86 Batch30 Loss: 0.0241
Epoch86 Batch40 Loss: 0.0254
Epoch86 Batch50 Loss: 0.0294
Epoch86 Batch60 Loss: 0.0300
Epoch86 Batch70 Loss: 0.0439
train Loss: 0.0345
Epoch 87/499
----------
Epoch87 Batch0 Loss: 0.0361
shuffling the dataset
Epoch87 Batch10 Loss: 0.0379
Epoch87 Batch20 Loss: 0.0387
Epoch87 Batch30 Loss: 0.0185
Epoch87 Batch40 Loss: 0.0259
Epoch87 Batch50 Loss: 0.0265
Epoch87 Batch60 Loss: 0.0216
Epoch87 Batch70 Loss: 0.0304
train Loss: 0.0331
Epoch 88/499
----------
Epoch88 Batch0 Loss: 0.0422
shuffling the dataset
Epoch88 Batch10 Loss: 0.0253
Epoch88 Batch20 Loss: 0.0311
Epoch88 Batch30 Loss: 0.0338
Epoch88 Batch40 Loss: 0.0284
Epoch88 Batch50 Loss: 0.0378
Epoch88 Batch60 Loss: 0.0287
Epoch88 Batch70 Loss: 0.0

Epoch112 Batch70 Loss: 0.0157
train Loss: 0.0203
Epoch 113/499
----------
Epoch113 Batch0 Loss: 0.0236
Epoch113 Batch10 Loss: 0.0229
Epoch113 Batch20 Loss: 0.0210
shuffling the dataset
Epoch113 Batch30 Loss: 0.0169
Epoch113 Batch40 Loss: 0.0152
Epoch113 Batch50 Loss: 0.0275
Epoch113 Batch60 Loss: 0.0188
Epoch113 Batch70 Loss: 0.0289
train Loss: 0.0209
Epoch 114/499
----------
Epoch114 Batch0 Loss: 0.0197
Epoch114 Batch10 Loss: 0.0167
Epoch114 Batch20 Loss: 0.0199
shuffling the dataset
Epoch114 Batch30 Loss: 0.0235
Epoch114 Batch40 Loss: 0.0197
Epoch114 Batch50 Loss: 0.0163
Epoch114 Batch60 Loss: 0.0194
Epoch114 Batch70 Loss: 0.0113
train Loss: 0.0197
Epoch 115/499
----------
Epoch115 Batch0 Loss: 0.0204
Epoch115 Batch10 Loss: 0.0116
Epoch115 Batch20 Loss: 0.0156
shuffling the dataset
Epoch115 Batch30 Loss: 0.0128
Epoch115 Batch40 Loss: 0.0147
Epoch115 Batch50 Loss: 0.0231
Epoch115 Batch60 Loss: 0.0138
Epoch115 Batch70 Loss: 0.0155
train Loss: 0.0189
Epoch 116/499
----------
Epoch116 Ba

Epoch139 Batch60 Loss: 0.0161
Epoch139 Batch70 Loss: 0.0143
train Loss: 0.0138
Epoch 140/499
----------
Epoch140 Batch0 Loss: 0.0164
Epoch140 Batch10 Loss: 0.0120
Epoch140 Batch20 Loss: 0.0128
Epoch140 Batch30 Loss: 0.0116
Epoch140 Batch40 Loss: 0.0184
shuffling the dataset
Epoch140 Batch50 Loss: 0.0129
Epoch140 Batch60 Loss: 0.0145
Epoch140 Batch70 Loss: 0.0130
train Loss: 0.0148
Epoch 141/499
----------
Epoch141 Batch0 Loss: 0.0283
Epoch141 Batch10 Loss: 0.0096
Epoch141 Batch20 Loss: 0.0129
Epoch141 Batch30 Loss: 0.0115
Epoch141 Batch40 Loss: 0.0151
Epoch141 Batch50 Loss: 0.0064
shuffling the dataset
Epoch141 Batch60 Loss: 0.0160
Epoch141 Batch70 Loss: 0.0147
train Loss: 0.0144
Epoch 142/499
----------
Epoch142 Batch0 Loss: 0.0084
Epoch142 Batch10 Loss: 0.0094
Epoch142 Batch20 Loss: 0.0139
Epoch142 Batch30 Loss: 0.0170
Epoch142 Batch40 Loss: 0.0140
Epoch142 Batch50 Loss: 0.0106
shuffling the dataset
Epoch142 Batch60 Loss: 0.0250
Epoch142 Batch70 Loss: 0.0154
train Loss: 0.0148
Epoch 

Epoch166 Batch60 Loss: 0.0108
Epoch166 Batch70 Loss: 0.0057
train Loss: 0.0107
Epoch 167/499
----------
Epoch167 Batch0 Loss: 0.0091
shuffling the dataset
Epoch167 Batch10 Loss: 0.0090
Epoch167 Batch20 Loss: 0.0145
Epoch167 Batch30 Loss: 0.0102
Epoch167 Batch40 Loss: 0.0111
Epoch167 Batch50 Loss: 0.0096
Epoch167 Batch60 Loss: 0.0156
Epoch167 Batch70 Loss: 0.0085
train Loss: 0.0105
Epoch 168/499
----------
Epoch168 Batch0 Loss: 0.0120
shuffling the dataset
Epoch168 Batch10 Loss: 0.0109
Epoch168 Batch20 Loss: 0.0424
Epoch168 Batch30 Loss: 0.0088
Epoch168 Batch40 Loss: 0.0110
Epoch168 Batch50 Loss: 0.0143
Epoch168 Batch60 Loss: 0.0080
Epoch168 Batch70 Loss: 0.0086
train Loss: 0.0106
Epoch 169/499
----------
Epoch169 Batch0 Loss: 0.0119
shuffling the dataset
Epoch169 Batch10 Loss: 0.0355
Epoch169 Batch20 Loss: 0.0110
Epoch169 Batch30 Loss: 0.0077
Epoch169 Batch40 Loss: 0.0071
Epoch169 Batch50 Loss: 0.0055
Epoch169 Batch60 Loss: 0.0095
Epoch169 Batch70 Loss: 0.0124
train Loss: 0.0101
Epoch 

Epoch193 Batch50 Loss: 0.0096
Epoch193 Batch60 Loss: 0.0127
Epoch193 Batch70 Loss: 0.0060
train Loss: 0.0115
Epoch 194/499
----------
Epoch194 Batch0 Loss: 0.0092
Epoch194 Batch10 Loss: 0.0049
Epoch194 Batch20 Loss: 0.0075
shuffling the dataset
Epoch194 Batch30 Loss: 0.0064
Epoch194 Batch40 Loss: 0.0068
Epoch194 Batch50 Loss: 0.0092
Epoch194 Batch60 Loss: 0.0081
Epoch194 Batch70 Loss: 0.0100
train Loss: 0.0095
Epoch 195/499
----------
Epoch195 Batch0 Loss: 0.0092
Epoch195 Batch10 Loss: 0.0092
Epoch195 Batch20 Loss: 0.0088
shuffling the dataset
Epoch195 Batch30 Loss: 0.0054
Epoch195 Batch40 Loss: 0.0064
Epoch195 Batch50 Loss: 0.0093
Epoch195 Batch60 Loss: 0.0086
Epoch195 Batch70 Loss: 0.0097
train Loss: 0.0097
Epoch 196/499
----------
Epoch196 Batch0 Loss: 0.0073
Epoch196 Batch10 Loss: 0.0089
Epoch196 Batch20 Loss: 0.0052
shuffling the dataset
Epoch196 Batch30 Loss: 0.0079
Epoch196 Batch40 Loss: 0.0384
Epoch196 Batch50 Loss: 0.0050
Epoch196 Batch60 Loss: 0.0059
Epoch196 Batch70 Loss: 0.

shuffling the dataset
Epoch220 Batch50 Loss: 0.0036
Epoch220 Batch60 Loss: 0.0080
Epoch220 Batch70 Loss: 0.0083
train Loss: 0.0078
Epoch 221/499
----------
Epoch221 Batch0 Loss: 0.0051
Epoch221 Batch10 Loss: 0.0062
Epoch221 Batch20 Loss: 0.0061
Epoch221 Batch30 Loss: 0.0063
Epoch221 Batch40 Loss: 0.0060
shuffling the dataset
Epoch221 Batch50 Loss: 0.0069
Epoch221 Batch60 Loss: 0.0079
Epoch221 Batch70 Loss: 0.0067
train Loss: 0.0070
Epoch 222/499
----------
Epoch222 Batch0 Loss: 0.0053
Epoch222 Batch10 Loss: 0.0090
Epoch222 Batch20 Loss: 0.0065
Epoch222 Batch30 Loss: 0.0057
Epoch222 Batch40 Loss: 0.0043
shuffling the dataset
Epoch222 Batch50 Loss: 0.0142
Epoch222 Batch60 Loss: 0.0043
Epoch222 Batch70 Loss: 0.0085
train Loss: 0.0081
Epoch 223/499
----------
Epoch223 Batch0 Loss: 0.0050
Epoch223 Batch10 Loss: 0.0060
Epoch223 Batch20 Loss: 0.0088
Epoch223 Batch30 Loss: 0.0108
Epoch223 Batch40 Loss: 0.0145
Epoch223 Batch50 Loss: 0.0033
shuffling the dataset
Epoch223 Batch60 Loss: 0.0168
Epo