In [1]:
import sys
sys.path.append('./models/')
import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
import time
import os
from data_loader import Dataset
import models.unet_normals as unet
from tensorboardX import SummaryWriter
# import OpenEXR, Imath
import imageio
from torchvision.utils import make_grid


### Setup Options
Set the various parameters:
- dataroot: The folder where the training data is stored
- file_list: List of filenames of images for training
- batchSize: Batch size for model
- shuffle: If true, will shuffle the dataset
- phase: If 'train', then it's in training mode.
- num_epochs: Number of epochs to train the model for
- imsize: Dimensions of the image (square)
- num_classes: Num of classes in the output
- gpu: Which GPU device to use
- logs_path: The path where the log files (tensorboard) will be saved.

In [2]:
class OPT():
    def __init__(self):
        self.dataroot = './data/'
        self.file_list = './data/datalist'
        self.batchSize = 16
        self.shuffle = False
        self.phase = 'train'
        self.num_epochs = 500
        self.imsize = (224,224)
        self.num_classes = int(2)
        self.gpu = '0'
        self.logs_path = 'logs/exp3'
        self.use_pretrained = False

opt = OPT()

### Setup logging and dataloaders

In [3]:
###################### Options #############################
phase = opt.phase
device = torch.device("cuda:"+ opt.gpu if torch.cuda.is_available() else "cpu")

###################### TensorBoardX #############################
if os.path.exists(opt.logs_path):
    raise Exception('The folder \"{}\" already exists! Define a new log path or delete old contents.'.format(opt.logs_path))
    
writer = SummaryWriter(opt.logs_path, comment='create-graph')
graph_created = False

###################### DataLoader #############################
dataloader = Dataset(opt)


### Create the model
We use a UNet model. The last few layers of this model are modified to return a 3 channel image, containing the x,y,z values of surface normal vectors.

In [4]:
###################### ModelBuilder #############################
model = unet.Unet(num_classes=opt.num_classes)
criterion = nn.CrossEntropyLoss(reduction='sum').to(device)

# Load weights from checkpoint
if (opt.use_pretrained == True):
    checkpoint_path = 'logs/exp1/checkpoints/checkpoint-epoch_405.pth'
    model.load_state_dict(torch.load(checkpoint_path))

model = model.to(device)
model.train()

###################### Setup Optimazation #############################
optimizer = torch.optim.Adam(model.parameters(), lr=0.0001, weight_decay=0.0001)
exp_lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=7, gamma=0.1)

###################### Loss fuction #############################
def flatten_logits(logits, number_of_classes):
    """Flattens the logits batch except for the logits dimension"""
    logits_permuted = logits.permute(0, 2, 3, 1)
    logits_permuted_cont = logits_permuted.contiguous()
    logits_flatten = logits_permuted_cont.view(-1, number_of_classes)
    return logits_flatten

def flatten_annotations(annotations):
    return annotations.view(-1)

def get_valid_annotations_index(flatten_annotations, mask_out_value=255):
    return torch.squeeze( torch.nonzero((flatten_annotations != mask_out_value )), 1)

### Train the model


In [5]:
###################### Train Model #############################
# Calculate total iter_num
total_iter_num = 0

for epoch in range(opt.num_epochs):#411, 411+opt.num_epochs):
    print('Epoch {}/{}'.format(epoch, opt.num_epochs - 1))
    print('-' * 10)

    # Each epoch has a training and validation phase
    running_loss = 0.0
    
    
    # Iterate over data.
    for i in range(int(dataloader.size()/opt.batchSize)):
        total_iter_num += 1
        
        # Get data
        inputs, labels =  dataloader.get_batch()
        inputs = inputs.to(device)
        labels = labels.to(device)
        
        # We need to flatten annotations and logits to apply index of valid annotations.
        anno_flatten = flatten_annotations(labels)
        index = get_valid_annotations_index(anno_flatten, mask_out_value=255)
        anno_flatten_valid = torch.index_select(anno_flatten, 0, index)
        
        ## Create Graph ##
        if graph_created == False:
            graph_created = True
            writer.add_graph(model, inputs, verbose=False)
        
        # Forward Prop
        optimizer.zero_grad()
        torch.set_grad_enabled(True)
        logits = model(inputs)
        
        # Calculate Loss
        logits_flatten = flatten_logits(logits, number_of_classes=opt.num_classes)
        logits_flatten_valid = torch.index_select(logits_flatten, 0, index)
        loss = criterion(logits_flatten_valid, anno_flatten_valid)
        
        # Backward Prop
        loss.backward()
        optimizer.step()

        # statistics
        running_loss += loss.item()
        writer.add_scalar('loss', loss.item(), total_iter_num)

        
        # Print image every N epochs
        nTestInterval = 1
        if (epoch % nTestInterval) == 0:
            img_tensor = inputs[:3].clone().cpu()
            
            output_tensor = torch.unsqueeze(torch.max(logits[:3], 1)[1].detach().cpu().float(), 1)
            output_tensor = torch.cat((output_tensor, output_tensor, output_tensor), 1)
            
            label_tensor = labels[:3].detach().cpu().float()
            label_tensor = torch.cat((label_tensor, label_tensor, label_tensor), 1)
            
            images = []
            for img, output, label in zip(img_tensor, output_tensor, label_tensor):
                images.append(img)
                images.append(output)
                images.append(label)

            grid_image = make_grid(images, 3, normalize=True, scale_each=False )
            writer.add_image('Train', grid_image, epoch)
        
        if (i % 2 == 0):
            print('Epoch{} Batch{} Loss: {:.4f}'.format(epoch, i, loss.item()))

    epoch_loss = running_loss / (dataloader.size()/opt.batchSize)
    writer.add_scalar('epoch_loss', epoch_loss, epoch)
    print('{} Loss: {:.4f}'.format(phase, epoch_loss))
    
    # Save the model checkpoint
    directory = opt.logs_path+'/checkpoints/'
    if not os.path.exists(directory):
        os.makedirs(directory)
        
    if (epoch % 5 == 0):
        filename = opt.logs_path + '/checkpoints/checkpoint-epoch_{}.pth'.format(epoch,i)
        torch.save(model.state_dict(), filename)
        

  

Epoch 0/499
----------


RuntimeError: cuda runtime error (59) : device-side assert triggered at /opt/conda/conda-bld/pytorch_1535491974311/work/aten/src/THC/generic/THCStorage.cpp:36