In [1]:
import numpy as np
from pathlib import Path

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
from torch.autograd import Variable
from torchvision.utils import save_image

In [2]:
import sys
sys.path.append('../src/')

from model import UNet
from dataset import SSSDataset
from loss import DiscriminativeLoss

In [3]:
# Number of sticks: passed to SSSDataset and, in the training loop, repeated
# once per image as the count list handed to DiscriminativeLoss.
n_sticks = 8

# Seed the global RNGs before the model/dataset cells run so that weight
# initialisation is reproducible across Restart & Run All.
# NOTE(review): SSSDataset may also draw from numpy's global RNG — assumption,
# confirm in ../src/dataset.py.
SEED = 0
np.random.seed(SEED)
torch.manual_seed(SEED)

In [4]:
# Model: instantiate the UNet, then move its parameters to the GPU.
# The training loop unpacks its forward output as (sem_predict, ins_predict).
model = UNet()
model = model.cuda()

In [5]:
# Dataset for train: the training split of SSSDataset (n_sticks forwarded),
# served one sample per batch in a fixed order.
train_dataset = SSSDataset(train=True, n_sticks=n_sticks)
train_dataloader = DataLoader(
    train_dataset,
    batch_size=1,
    shuffle=False,    # deterministic iteration order
    num_workers=0,    # load samples in the main process
    pin_memory=True,  # page-locked host buffers for faster .cuda() copies
)

In [6]:
# Loss Functions
# Instance branch: project-local DiscriminativeLoss. delta_var / delta_dist are
# presumably its pull / push margins — confirm against ../src/loss.py.
criterion_disc = DiscriminativeLoss(
    delta_var=0.5,
    delta_dist=1.5,
    norm=2,
    usegpu=True,
).cuda()
# Semantic branch: standard cross entropy, applied in the training loop to
# logits flattened to one row per pixel.
criterion_ce = nn.CrossEntropyLoss().cuda()

In [7]:
# Optimizer: SGD with momentum and L2 weight decay over all UNet parameters.
parameters = model.parameters()
optimizer = optim.SGD(
    parameters,
    lr=0.01,
    momentum=0.9,
    weight_decay=0.001,
)
# Divide the learning rate by 10 after 10 consecutive scheduler.step() calls
# without a decrease in the monitored value (mode='min').
scheduler = optim.lr_scheduler.ReduceLROnPlateau(
    optimizer=optimizer,
    mode='min',
    factor=0.1,
    patience=10,
    verbose=True,
)

In [8]:
# Train
# Runs up to 300 epochs over train_dataloader, optimising the sum of the
# discriminative (instance) loss and the cross-entropy (semantic) loss.
# After each epoch the mean discriminative loss drives the LR scheduler and,
# when it improves on the best seen so far, the weights are saved to
# ../model/model.pth.
model_dir = Path('../model')
# torch.save does not create parent directories; create the checkpoint dir now
# so the first "Best Model!" save cannot fail after a full epoch of training.
model_dir.mkdir(parents=True, exist_ok=True)


def _loss_to_float(loss_value):
    """Return the value of a single-element loss as a Python float.

    On PyTorch >= 0.4 a reduced loss is a 0-dim tensor, so the old
    ``.cpu().data.numpy()[0]`` idiom raises IndexError; ``.item()`` is the
    supported accessor there. Older releases (no ``.item()``) return a
    one-element Variable, whose value is read with ``.data[0]``.
    """
    if hasattr(loss_value, 'item'):
        return loss_value.item()
    return loss_value.data[0]


best_loss = np.inf
for epoch in range(300):
    print(f'epoch : {epoch}')
    # Explicitly enable training mode (relevant if UNet contains dropout or
    # batch-norm layers — not visible from here).
    model.train()
    disc_losses = []
    ce_losses = []
    for batched in train_dataloader:
        images, sem_labels, ins_labels = batched

        images = Variable(images).cuda()
        sem_labels = Variable(sem_labels).cuda()
        ins_labels = Variable(ins_labels).cuda()
        optimizer.zero_grad()

        sem_predict, ins_predict = model(images)

        # Discriminative Loss: one n_sticks entry per image in the batch
        # (presumably the per-image instance count — confirm in DiscriminativeLoss).
        disc_loss = criterion_disc(ins_predict,
                                   ins_labels,
                                   [n_sticks] * len(images))
        disc_losses.append(_loss_to_float(disc_loss))

        # Cross Entropy Loss: argmax over dim 1 turns sem_labels into class
        # indices (labels assumed one-hot along dim 1 — TODO confirm). Logits are
        # moved channels-last and flattened to (pixels, n_classes); the class
        # count is read from the prediction instead of being hard-coded.
        n_classes = sem_predict.size(1)
        _, sem_labels_ce = sem_labels.max(1)
        sem_logits_flat = (sem_predict.permute(0, 2, 3, 1)
                           .contiguous()
                           .view(-1, n_classes))
        ce_loss = criterion_ce(sem_logits_flat, sem_labels_ce.view(-1))
        ce_losses.append(_loss_to_float(ce_loss))

        loss = disc_loss + ce_loss
        loss.backward()
        optimizer.step()
    disc_loss = np.mean(disc_losses)
    ce_loss = np.mean(ce_losses)
    print(f'DiscriminativeLoss: {disc_loss:.4f}')
    print(f'CrossEntropyLoss: {ce_loss:.4f}')
    scheduler.step(disc_loss)
    if disc_loss < best_loss:
        best_loss = disc_loss
        print('Best Model!')
        modelname = 'model.pth'
        torch.save(model.state_dict(), model_dir.joinpath(modelname))

epoch : 0
img <class 'torch.FloatTensor'>
img <class 'torch.FloatTensor'>
img <class 'torch.FloatTensor'>
img <class 'torch.FloatTensor'>
Variable containing:

Columns 0 to 6 
  289.1282   340.1148   372.1285   317.8081   372.1449   286.3738   530.7367
 -324.3894  -499.7675  -411.0128  -289.9552  -401.6709  -369.9888  -514.5149
 -289.9343  -518.9856  -300.0668  -245.7392  -263.9046  -545.4254  -468.9028
 -295.9814  -463.8456  -353.8270  -269.4973  -300.6821  -461.2397  -454.7359
  286.6226   582.5879   208.6590   268.9242   293.7649   584.9517   447.5388
 -328.2320  -732.3148  -453.0490  -330.2997  -402.1884  -509.5753  -576.3776
  -35.4079   -62.8584   -88.7989   -29.6020   -67.4757   -95.4357  -137.2250
  247.6983   400.9331   247.5770   260.8599   378.2213   427.8092   480.2264
 -439.2777  -650.9365  -437.8488  -394.9053  -424.3885  -614.9483  -764.9960
 -370.1022  -639.5310  -417.9650  -389.3400  -487.3299  -735.7640  -772.9658
  216.9915   455.1708   177.0734   215.8366   259.1003

Variable containing:

Columns 0 to 6 
  438.9550  1377.1028  1572.6779   112.7945   211.6793   162.0668   649.8199
 -630.5518 -2589.1506 -2933.7253   -34.5929  -414.3949  -226.3064 -1042.4166
 -126.4327  -641.5596  -743.9536  -202.2277  -140.6494  -281.8228  -337.2015
 -397.9798  -888.5202 -1066.7419  -124.9168  -120.5751  -221.4184  -445.2255
  220.3957  -503.2359  -660.0302   429.4461   -15.2550   556.6344  -379.0162
 -554.0085 -1492.8199 -1584.6968  -303.8248  -447.7265  -477.0030  -808.8481
 -238.5503  -521.5760  -617.5172   260.1634    81.5956    20.5767   -25.1520
  415.0131   666.1862   734.2755    38.9789   126.9678   403.4213    91.0016
 -575.9492 -1075.5870 -1225.9556  -223.9177  -100.7590  -420.3942  -198.8433
 -801.9533 -1133.1392 -1208.0857    66.6105  -168.2590  -580.8151  -145.6729
  112.3091  -356.7528  -454.0132   213.5525    83.4890   364.9939   -85.3909
  109.5673  -757.5605  -898.1311  -294.8457   -88.2081  -136.3251  -382.5033
  267.8368  -922.8942 -1117.2452   277

KeyboardInterrupt: 