In [1]:
import sys
sys.path.append('./models/')
import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
import time
import os
from data_loader import Dataset
import models.unet_normals as unet
from tensorboardX import SummaryWriter
# import OpenEXR, Imath
import imageio

### Setup Options
Set the various parameters:
- dataroot: Folder where the training data is stored
- file_list: Path to the list of image filenames used for training
- batchSize: Batch size for the model
- shuffle: If true, the dataset is shuffled
- phase: If 'train', the notebook runs in training mode
- num_epochs: Number of epochs to train the model for
- imsize: Side length of the (square) input images
- num_classes: Number of classes in the model output
- gpu: Index of the GPU device to use, as a string (e.g. '0')
- logs_path: Path where the TensorBoard log files and model checkpoints are saved
- use_pretrained: If true, initialize the model from a saved checkpoint before training

In [2]:
class OPT():
    """Plain container for the experiment's hyper-parameters and paths.

    The rest of the notebook reads these as attributes, e.g. ``opt.batchSize``
    in the training loop and ``opt.gpu`` when choosing the torch device.
    """

    def __init__(self):
        # --- data ---
        self.dataroot = './data/'
        self.file_list = './data/datalist'
        # --- loader / training schedule ---
        self.batchSize = 16
        self.shuffle = True
        self.phase = 'train'
        self.num_epochs = 500
        # --- model input / output ---
        self.imsize = 224
        self.num_classes = 2
        # --- runtime ---
        self.gpu = '0'  # kept as a string: it is concatenated into "cuda:<gpu>"
        self.logs_path = 'logs/exp2'
        self.use_pretrained = True

opt = OPT()

### Set up logging and data loading

In [3]:
###################### Options #############################
phase = opt.phase
# Use the GPU index from OPT when CUDA is present; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:" + opt.gpu)
else:
    device = torch.device("cpu")

###################### TensorBoardX #############################
# Event files go to opt.logs_path. There is no check for an existing folder, so a
# reused path will mix event files from earlier runs with this one.
writer = SummaryWriter(opt.logs_path, comment='create-graph')
graph_created = False  # the training loop flips this once the model graph is logged

###################### DataLoader #############################
dataloader = Dataset(opt)


shuffling the dataset


### Create the model
We use a UNet model whose final layer outputs `num_classes` channels (2 here) of per-pixel class scores. It is trained as a segmentation network with a cross-entropy loss; pixels labelled 255 are ignored.

In [4]:
###################### ModelBuilder #############################
model = unet.Unet(num_classes=opt.num_classes)
# Summed (not averaged) cross-entropy: the loss magnitude grows with the number of
# valid (unmasked) pixels in the batch.
criterion = nn.CrossEntropyLoss(reduction='sum').to(device)

# Optionally warm-start from a checkpoint of a previous experiment.
if opt.use_pretrained:
    checkpoint_path = 'logs/exp1/checkpoints/checkpoint-epoch_405.pth'
    # map_location='cpu' lets a checkpoint saved on a GPU load on a CPU-only machine
    # (the setup cell falls back to the CPU); the model is moved to `device` below.
    model.load_state_dict(torch.load(checkpoint_path, map_location='cpu'))

model = model.to(device)
model.train()

###################### Setup Optimization #############################
optimizer = torch.optim.Adam(model.parameters(), lr=0.0001, weight_decay=0.0001)
# NOTE(review): this scheduler is never stepped in the training loop, so the
# learning rate stays at 1e-4. Calling exp_lr_scheduler.step() once per epoch would
# multiply it by 0.1 every 7 epochs. Decide which is intended before enabling it.
exp_lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=7, gamma=0.1)

###################### Loss function #############################
def flatten_logits(logits, number_of_classes=None):
    """Flatten a batch of logits to one row of class scores per element.

    Args:
        logits: tensor laid out as (N, C, *spatial) with the class scores on dim 1,
            e.g. (N, C, H, W) for 2-D segmentation.
        number_of_classes: expected C. If None, it is taken from ``logits.shape[1]``.

    Returns:
        Tensor of shape (N * prod(spatial), C). Rows are ordered batch-major, then by
        spatial position in row-major order, which is the same order
        ``reshape(-1)`` gives on a (N, *spatial) label tensor.

    Raises:
        ValueError: if ``logits`` has fewer than 2 dims, or if its dim 1 does not
            equal ``number_of_classes``.
    """
    if logits.dim() < 2:
        raise ValueError('logits must have at least 2 dims (N, C, ...), got shape {}'.format(tuple(logits.shape)))
    num_channels = logits.shape[1]
    if number_of_classes is None:
        number_of_classes = num_channels
    elif num_channels != number_of_classes:
        # A plain view(-1, number_of_classes) would silently mix channels and pixels here.
        raise ValueError('expected {} classes on dim 1 of logits, got {}'.format(number_of_classes, num_channels))
    # Move the class dim last: (N, C, d1, ..., dk) -> (N, d1, ..., dk, C).
    order = (0,) + tuple(range(2, logits.dim())) + (1,)
    # reshape copies only if needed (the permuted tensor is usually non-contiguous).
    return logits.permute(*order).reshape(-1, number_of_classes)

def flatten_annotations(annotations):
    """Flatten a label tensor of any shape to 1-D, in row-major order.

    Uses ``reshape`` instead of ``view`` so that non-contiguous inputs (for example
    after a transpose or permute) are accepted. A view is still returned whenever
    the memory layout allows it.
    """
    return annotations.reshape(-1)

def get_valid_annotations_index(flatten_annotations, mask_out_value=255):
    """Return the positions in a 1-D label tensor whose value is not ``mask_out_value``.

    The result is a 1-D integer index tensor suitable for ``torch.index_select``.
    It is empty when every label equals ``mask_out_value``.
    """
    keep = flatten_annotations != mask_out_value
    # nonzero() yields shape (k, 1) for a 1-D mask; drop the trailing axis.
    return keep.nonzero().squeeze(1)

### Train the model


In [5]:
###################### Train Model #############################
# Hard-coded resume offsets. They only affect epoch numbering, TensorBoard step
# indices and checkpoint filenames, not which weights are loaded.
# NOTE(review): the pretrained checkpoint loaded in the model cell is epoch 405, but
# numbering resumes at 411. Confirm these offsets.
START_EPOCH = 411
total_iter_num = 2467  # global TensorBoard step, incremented once per batch

end_epoch = START_EPOCH + opt.num_epochs  # exclusive upper bound
num_batches = dataloader.size() // opt.batchSize  # get_batch() calls per epoch

checkpoint_dir = os.path.join(opt.logs_path, 'checkpoints')
os.makedirs(checkpoint_dir, exist_ok=True)

# Another cell (e.g. inference) may have left the model in eval mode.
model.train()

for epoch in range(START_EPOCH, end_epoch):
    print('Epoch {}/{}'.format(epoch, end_epoch - 1))
    print('-' * 10)

    # Training only: this loop has no validation phase.
    running_loss = 0.0

    # Iterate over data.
    for i in range(num_batches):
        total_iter_num += 1

        # Get data
        inputs, labels = dataloader.get_batch()
        inputs = inputs.to(device)
        labels = labels.to(device)

        # Flatten the labels and keep only pixels whose label is not the 255 "ignore" value.
        anno_flatten = flatten_annotations(labels)
        index = get_valid_annotations_index(anno_flatten, mask_out_value=255)
        anno_flatten_valid = torch.index_select(anno_flatten, 0, index)

        ## Log the model graph once, on the first batch ##
        if not graph_created:
            graph_created = True
            writer.add_graph(model, inputs, verbose=False)

        # Forward Prop
        optimizer.zero_grad()
        # Defensive: make sure autograd is on even if an earlier cell turned it off.
        torch.set_grad_enabled(True)
        logits = model(inputs)

        # Calculate the loss on the same valid pixels
        logits_flatten = flatten_logits(logits, number_of_classes=opt.num_classes)
        logits_flatten_valid = torch.index_select(logits_flatten, 0, index)
        loss = criterion(logits_flatten_valid, anno_flatten_valid)

        # Backward Prop
        loss.backward()
        optimizer.step()

        # statistics (one .item() call = one device sync per batch)
        batch_loss = loss.item()
        running_loss += batch_loss
        writer.add_scalar('loss', batch_loss, total_iter_num)

        if i % 2 == 0:
            print('Epoch{} Batch{} Loss: {:.4f}'.format(epoch, i, batch_loss))

    # Average over the batches actually run; guard against a dataset smaller than one batch.
    epoch_loss = running_loss / max(num_batches, 1)
    writer.add_scalar('epoch_loss', epoch_loss, epoch)
    print('{} Loss: {:.4f}'.format(phase, epoch_loss))

    # Save a model checkpoint every 5 epochs
    if epoch % 5 == 0:
        filename = os.path.join(checkpoint_dir, 'checkpoint-epoch_{}.pth'.format(epoch))
        torch.save(model.state_dict(), filename)
        

  

Epoch 411/499
----------
Epoch411 Batch0 Loss: 24014.5840
Epoch411 Batch2 Loss: 30918.6602
Epoch411 Batch4 Loss: 57753.9141
train Loss: 50896.9784
Epoch 412/499
----------
shuffling the dataset
Epoch412 Batch0 Loss: 32984.0469
Epoch412 Batch2 Loss: 51051.0391
Epoch412 Batch4 Loss: 54358.9141
train Loss: 42741.1125
Epoch 413/499
----------
shuffling the dataset
Epoch413 Batch0 Loss: 45512.9883
Epoch413 Batch2 Loss: 41568.8477
Epoch413 Batch4 Loss: 29749.9844
train Loss: 38696.4941
Epoch 414/499
----------
shuffling the dataset
Epoch414 Batch0 Loss: 34415.5977
Epoch414 Batch2 Loss: 36843.5039
Epoch414 Batch4 Loss: 17672.2773
train Loss: 34737.9544
Epoch 415/499
----------
shuffling the dataset
Epoch415 Batch0 Loss: 21073.9160
Epoch415 Batch2 Loss: 45369.1875
Epoch415 Batch4 Loss: 30016.0117
train Loss: 29365.3391
Epoch 416/499
----------
Epoch416 Batch0 Loss: 22702.3105
shuffling the dataset
Epoch416 Batch2 Loss: 17862.7070
Epoch416 Batch4 Loss: 25138.8984
train Loss: 26303.1103
Epoch 41

train Loss: 19246.5250
Epoch 460/499
----------
Epoch460 Batch0 Loss: 17640.6992
Epoch460 Batch2 Loss: 15992.4561
Epoch460 Batch4 Loss: 27429.5039
shuffling the dataset
train Loss: 18975.7098
Epoch 461/499
----------
Epoch461 Batch0 Loss: 14023.4863
Epoch461 Batch2 Loss: 25314.2422
Epoch461 Batch4 Loss: 26069.8867
train Loss: 18998.0709
Epoch 462/499
----------
shuffling the dataset
Epoch462 Batch0 Loss: 16489.0410
Epoch462 Batch2 Loss: 19440.7070
Epoch462 Batch4 Loss: 20333.5391
train Loss: 19615.8581
Epoch 463/499
----------
shuffling the dataset
Epoch463 Batch0 Loss: 14297.1455
Epoch463 Batch2 Loss: 14138.3203
Epoch463 Batch4 Loss: 18427.3867
train Loss: 18739.7320
Epoch 464/499
----------
shuffling the dataset
Epoch464 Batch0 Loss: 20317.9609
Epoch464 Batch2 Loss: 16995.1543
Epoch464 Batch4 Loss: 23076.9668
train Loss: 19251.6334
Epoch 465/499
----------
shuffling the dataset
Epoch465 Batch0 Loss: 22094.6309
Epoch465 Batch2 Loss: 18682.4707
Epoch465 Batch4 Loss: 19651.7090
train Lo

train Loss: 18234.2836
Epoch 509/499
----------
Epoch509 Batch0 Loss: 20100.2285
Epoch509 Batch2 Loss: 14096.3086
Epoch509 Batch4 Loss: 15514.2969
shuffling the dataset
train Loss: 15448.0827
Epoch 510/499
----------
Epoch510 Batch0 Loss: 14056.1611
Epoch510 Batch2 Loss: 10259.4141
Epoch510 Batch4 Loss: 22803.6738
shuffling the dataset
train Loss: 15683.4144
Epoch 511/499
----------
Epoch511 Batch0 Loss: 13170.0635
Epoch511 Batch2 Loss: 14827.2637
Epoch511 Batch4 Loss: 11317.6992
train Loss: 15257.8714
Epoch 512/499
----------
shuffling the dataset
Epoch512 Batch0 Loss: 10377.0439
Epoch512 Batch2 Loss: 15412.4863
Epoch512 Batch4 Loss: 15908.9512
train Loss: 14807.8194
Epoch 513/499
----------
shuffling the dataset
Epoch513 Batch0 Loss: 11034.1914
Epoch513 Batch2 Loss: 16755.4570
Epoch513 Batch4 Loss: 9373.5527
train Loss: 13524.0623
Epoch 514/499
----------
shuffling the dataset
Epoch514 Batch0 Loss: 20788.6172
Epoch514 Batch2 Loss: 20590.1641
Epoch514 Batch4 Loss: 11381.9414
train Los

train Loss: 13245.3581
Epoch 558/499
----------
Epoch558 Batch0 Loss: 18090.0430
Epoch558 Batch2 Loss: 17582.8203
Epoch558 Batch4 Loss: 13758.9229
shuffling the dataset
train Loss: 14086.5508
Epoch 559/499
----------
Epoch559 Batch0 Loss: 14885.0479
Epoch559 Batch2 Loss: 13656.2090
Epoch559 Batch4 Loss: 7759.0327
shuffling the dataset
train Loss: 11937.6543
Epoch 560/499
----------
Epoch560 Batch0 Loss: 12279.1416
Epoch560 Batch2 Loss: 15011.6680
Epoch560 Batch4 Loss: 16130.4365
shuffling the dataset
train Loss: 12990.3364
Epoch 561/499
----------
Epoch561 Batch0 Loss: 9330.2510
Epoch561 Batch2 Loss: 13766.3154
Epoch561 Batch4 Loss: 15321.4893
train Loss: 12108.1702
Epoch 562/499
----------
shuffling the dataset
Epoch562 Batch0 Loss: 11555.2148
Epoch562 Batch2 Loss: 10505.3457
Epoch562 Batch4 Loss: 19568.9512
train Loss: 12557.6739
Epoch 563/499
----------
shuffling the dataset
Epoch563 Batch0 Loss: 9585.8252
Epoch563 Batch2 Loss: 9510.2949
Epoch563 Batch4 Loss: 13469.5869
train Loss: 

train Loss: 11446.6784
Epoch 607/499
----------
Epoch607 Batch0 Loss: 15105.9316
Epoch607 Batch2 Loss: 10706.2236
Epoch607 Batch4 Loss: 12817.7402
shuffling the dataset
train Loss: 11977.9925
Epoch 608/499
----------
Epoch608 Batch0 Loss: 12658.3203
Epoch608 Batch2 Loss: 12255.4785
Epoch608 Batch4 Loss: 12180.7832
shuffling the dataset
train Loss: 10912.0884
Epoch 609/499
----------
Epoch609 Batch0 Loss: 9738.4385
Epoch609 Batch2 Loss: 11498.6152
Epoch609 Batch4 Loss: 14327.1016
shuffling the dataset
train Loss: 11279.3427
Epoch 610/499
----------
Epoch610 Batch0 Loss: 9586.9414
Epoch610 Batch2 Loss: 9511.2041
Epoch610 Batch4 Loss: 15859.9307
shuffling the dataset
train Loss: 11105.9966
Epoch 611/499
----------
Epoch611 Batch0 Loss: 14110.4766
Epoch611 Batch2 Loss: 14357.6816
Epoch611 Batch4 Loss: 10147.3086
train Loss: 10640.1177
Epoch 612/499
----------
shuffling the dataset
Epoch612 Batch0 Loss: 9306.7881
Epoch612 Batch2 Loss: 14848.7705
Epoch612 Batch4 Loss: 7734.8896
train Loss: 9

Epoch656 Batch0 Loss: 6246.4346
Epoch656 Batch2 Loss: 10850.4658
shuffling the dataset
Epoch656 Batch4 Loss: 11612.1914
train Loss: 9445.6611
Epoch 657/499
----------
Epoch657 Batch0 Loss: 10370.4863
Epoch657 Batch2 Loss: 8282.4102
Epoch657 Batch4 Loss: 9181.8555
shuffling the dataset
train Loss: 8800.0930
Epoch 658/499
----------
Epoch658 Batch0 Loss: 13994.4736
Epoch658 Batch2 Loss: 9232.5283
Epoch658 Batch4 Loss: 7377.8745
shuffling the dataset
train Loss: 9889.3301
Epoch 659/499
----------
Epoch659 Batch0 Loss: 5958.0713
Epoch659 Batch2 Loss: 7716.4067
Epoch659 Batch4 Loss: 10799.3477
shuffling the dataset
train Loss: 9240.3866
Epoch 660/499
----------
Epoch660 Batch0 Loss: 16178.2725
Epoch660 Batch2 Loss: 6497.5811
Epoch660 Batch4 Loss: 14184.0254
shuffling the dataset
train Loss: 9446.2759
Epoch 661/499
----------
Epoch661 Batch0 Loss: 6836.8125
Epoch661 Batch2 Loss: 9950.1250
Epoch661 Batch4 Loss: 11626.1299
train Loss: 9858.4358
Epoch 662/499
----------
shuffling the dataset
Ep

shuffling the dataset
Epoch705 Batch4 Loss: 7270.1172
train Loss: 6638.0120
Epoch 706/499
----------
Epoch706 Batch0 Loss: 4600.6460
Epoch706 Batch2 Loss: 10230.2832
shuffling the dataset
Epoch706 Batch4 Loss: 7021.0098
train Loss: 6502.9798
Epoch 707/499
----------
Epoch707 Batch0 Loss: 4907.7246
Epoch707 Batch2 Loss: 7152.8545
Epoch707 Batch4 Loss: 5622.7944
shuffling the dataset
train Loss: 5820.6613
Epoch 708/499
----------
Epoch708 Batch0 Loss: 7380.8081
Epoch708 Batch2 Loss: 7737.0366
Epoch708 Batch4 Loss: 7875.7900
shuffling the dataset
train Loss: 6517.7902
Epoch 709/499
----------
Epoch709 Batch0 Loss: 9564.9902
Epoch709 Batch2 Loss: 4486.6113
Epoch709 Batch4 Loss: 5620.5059
shuffling the dataset
train Loss: 6100.5180
Epoch 710/499
----------
Epoch710 Batch0 Loss: 3892.2122
Epoch710 Batch2 Loss: 4199.7314
Epoch710 Batch4 Loss: 5145.4521
shuffling the dataset
train Loss: 6211.7681
Epoch 711/499
----------
Epoch711 Batch0 Loss: 6750.0991
Epoch711 Batch2 Loss: 4796.5757
Epoch711 

Epoch755 Batch0 Loss: 8668.6592
Epoch755 Batch2 Loss: 9391.1426
shuffling the dataset
Epoch755 Batch4 Loss: 8636.6680
train Loss: 8771.2872
Epoch 756/499
----------
Epoch756 Batch0 Loss: 9747.7490
Epoch756 Batch2 Loss: 11132.3789
shuffling the dataset
Epoch756 Batch4 Loss: 9039.7490
train Loss: 8089.5930
Epoch 757/499
----------
Epoch757 Batch0 Loss: 6957.4893
Epoch757 Batch2 Loss: 12575.8242
Epoch757 Batch4 Loss: 6204.0742
shuffling the dataset
train Loss: 7844.5730
Epoch 758/499
----------
Epoch758 Batch0 Loss: 4643.0957
Epoch758 Batch2 Loss: 7882.3848
Epoch758 Batch4 Loss: 10398.5117
shuffling the dataset
train Loss: 7208.1891
Epoch 759/499
----------
Epoch759 Batch0 Loss: 7092.7407
Epoch759 Batch2 Loss: 11150.9189
Epoch759 Batch4 Loss: 6409.9629
shuffling the dataset
train Loss: 6461.5420
Epoch 760/499
----------
Epoch760 Batch0 Loss: 5018.5933
Epoch760 Batch2 Loss: 9921.5225
Epoch760 Batch4 Loss: 5797.2261
shuffling the dataset
train Loss: 6576.9380
Epoch 761/499
----------
Epoch7

RuntimeError: Unknown error -1

In [None]:


###################### Inference / visualization #############################
# Load the most recent checkpoint written by the training loop. Checkpoints are only
# saved every 5 epochs, so the last value of `epoch` (e.g. after an interrupted run)
# does not necessarily have a file on disk.
checkpoint_dir = os.path.join(opt.logs_path, 'checkpoints')
ckpt_prefix, ckpt_suffix = 'checkpoint-epoch_', '.pth'
saved_epochs = [
    int(name[len(ckpt_prefix):-len(ckpt_suffix)])
    for name in os.listdir(checkpoint_dir)
    if name.startswith(ckpt_prefix)
    and name.endswith(ckpt_suffix)
    and name[len(ckpt_prefix):-len(ckpt_suffix)].isdigit()
]
if not saved_epochs:
    raise FileNotFoundError('No checkpoints found in {}'.format(checkpoint_dir))
checkpoint_path = os.path.join(checkpoint_dir, '{}{}{}'.format(ckpt_prefix, max(saved_epochs), ckpt_suffix))
print("Initializing weights from: {}...".format(checkpoint_path))

# map_location='cpu' loads all tensors onto the CPU first, so a GPU-saved checkpoint
# also works on a CPU-only machine; load_state_dict then copies them into the model.
model.load_state_dict(torch.load(checkpoint_path, map_location='cpu'))

# Evaluate on the device chosen in the setup cell. The model is deliberately NOT
# wrapped in nn.DataParallel. Rebinding the global `model` to a wrapper would make a
# later re-run of the training cell save 'module.'-prefixed state dicts, which the
# plain Unet in the model cell cannot load.
model = model.to(device)
model.eval()

results_dir = os.path.join('data', 'results')
os.makedirs(results_dir, exist_ok=True)

num_batches = dataloader.size() // opt.batchSize
# Inference only: no autograd graph is built and the optimizer is not touched.
with torch.no_grad():
    for ii in range(num_batches):
        inputs, _ = dataloader.get_batch()
        inputs = inputs.to(device)

        logits = model(inputs)
        # Class scores are on dim 1 (the layout flatten_logits relies on). Taking the
        # argmax over that dim gives one predicted label per pixel for each image.
        predictions = torch.argmax(logits, dim=1).cpu().numpy()

        # Visualize the first image of the batch.
        # NOTE(review): assumes `inputs` is (N, 3, H, W). If the Dataset normalizes the
        # images, imshow may clip the displayed colours. TODO: confirm the value range.
        rgb_img = np.transpose(inputs[0].cpu().numpy(), (1, 2, 0))

        fig, (ax0, ax1) = plt.subplots(1, 2)
        ax0.imshow(rgb_img)
        ax0.set_title('Source RGB Image')
        ax1.imshow(predictions[0])
        ax1.set_title('Predicted Labels')

        fig.savefig(os.path.join(results_dir, '%04d-results.png' % ii))
        plt.close(fig)  # avoid accumulating open figures across batches