In [1]:
import numpy as np
from pathlib import Path

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
from torch.autograd import Variable
from torchvision.utils import save_image

In [2]:
import sys
# Add the sibling ../src directory to the import search path so the
# project-local modules below can be imported. The path is relative, so it
# resolves against the kernel's current working directory.
sys.path.append('../src/')

# Project-local modules, found through the path entry added above.
from model import UNet
from dataset import SSSDataset
from loss import DiscriminativeLoss

In [3]:
# Passed to SSSDataset(n_sticks=...) and, once per image in the batch, to the
# discriminative loss as its third argument.
# Presumably the number of stick instances in each image -- TODO confirm
# against SSSDataset / DiscriminativeLoss.
n_sticks = 8

In [4]:
# Network: instantiate the UNet first, then move it onto the GPU.
model = UNet()
model = model.cuda()

In [5]:
# Training split of the dataset, wrapped in a loader that yields one sample
# per batch in the dataset's own order.
# NOTE(review): shuffle=False on a training loader is unusual -- confirm it is
# intentional (e.g. SSSDataset already randomises its samples).
train_dataset = SSSDataset(n_sticks=n_sticks, train=True)
train_dataloader = DataLoader(
    train_dataset,
    batch_size=1,
    shuffle=False,
    num_workers=0,
    pin_memory=True,
)

In [6]:
# Loss functions, both moved to the GPU.

# Discriminative loss: applied to the instance predictions in the training loop.
criterion_disc = DiscriminativeLoss(
    delta_var=0.5,
    delta_dist=1.5,
    norm=2,
    usegpu=True,
)
criterion_disc = criterion_disc.cuda()

# Cross-entropy loss: applied to the semantic predictions in the training loop.
criterion_ce = nn.CrossEntropyLoss()
criterion_ce = criterion_ce.cuda()

In [7]:
# Optimizer: SGD with momentum and weight decay over all model parameters.
parameters = model.parameters()
optimizer = optim.SGD(
    parameters,
    lr=0.01,
    momentum=0.9,
    weight_decay=0.001,
)

# Learning-rate schedule: multiply the learning rate by `factor` once the
# value handed to scheduler.step() has stopped decreasing for `patience` epochs.
scheduler = optim.lr_scheduler.ReduceLROnPlateau(
    optimizer,
    mode='min',
    factor=0.1,
    patience=10,
    verbose=True,
)

In [8]:
# Train
#
# Runs 300 epochs over `train_dataloader`. Each step minimises the sum of the
# discriminative loss (on the instance predictions) and the cross-entropy loss
# (on the semantic predictions). After every epoch the mean discriminative
# loss drives the LR scheduler, and the model weights are saved to
# ../model/model.pth whenever that mean reaches a new minimum.
model_dir = Path('../model')
# torch.save does not create missing directories; without this the first
# "Best Model!" checkpoint fails on a fresh checkout.
model_dir.mkdir(parents=True, exist_ok=True)

# Be explicit about train mode so re-running this cell after an evaluation
# cell (model.eval()) still trains with the intended layer behaviour.
model.train()

best_loss = np.inf
for epoch in range(300):
    print(f'epoch : {epoch}')
    disc_losses = []
    ce_losses = []
    for batched in train_dataloader:
        images, sem_labels, ins_labels = batched

        # Variable() is required on PyTorch < 0.4 and a harmless pass-through
        # on later versions, so it is kept for compatibility.
        images = Variable(images).cuda()
        sem_labels = Variable(sem_labels).cuda()
        ins_labels = Variable(ins_labels).cuda()
        model.zero_grad()

        sem_predict, ins_predict = model(images)

        # Discriminative Loss
        disc_loss = criterion_disc(ins_predict,
                                   ins_labels,
                                   [n_sticks] * len(images))
        # ravel() turns both a 0-dim array (PyTorch >= 0.4 scalar loss) and a
        # 1-element array (older PyTorch) into a 1-element array; plain
        # `numpy()[0]` raises IndexError on the 0-dim case.
        disc_losses.append(disc_loss.cpu().data.numpy().ravel()[0])

        # Cross Entropy Loss
        # Convert the per-class label maps to class indices via argmax over
        # dim 1 (assumes sem_labels is one-hot along that dim -- TODO confirm).
        _, sem_labels_ce = sem_labels.max(1)
        # Take the class count from the prediction instead of hard-coding 2.
        # Assumes sem_predict is laid out as (batch, classes, H, W).
        n_classes = sem_predict.size(1)
        ce_loss = criterion_ce(sem_predict.permute(0, 2, 3, 1)
                                   .contiguous().view(-1, n_classes),
                               sem_labels_ce.view(-1))
        ce_losses.append(ce_loss.cpu().data.numpy().ravel()[0])

        # Optimise both objectives jointly.
        loss = disc_loss + ce_loss
        loss.backward()
        optimizer.step()
    disc_loss = np.mean(disc_losses)
    ce_loss = np.mean(ce_losses)
    print(f'DiscriminativeLoss: {disc_loss:.4f}')
    print(f'CrossEntropyLoss: {ce_loss:.4f}')
    # Only the discriminative loss is monitored for LR decay and checkpointing.
    scheduler.step(disc_loss)
    if disc_loss < best_loss:
        best_loss = disc_loss
        print('Best Model!')
        modelname = 'model.pth'
        torch.save(model.state_dict(), model_dir.joinpath(modelname))

epoch : 0
255 1 1
Variable containing:
 113.1543  213.3946   23.5139  186.1777  141.5761   88.4837  176.2586  217.1681
 568.4436  727.1479  496.0437  482.3612  487.9170  470.5247  802.3741  655.9179
 325.6241  405.1455  264.9794  242.1625  299.2707  245.0710  460.1209  339.9286
-378.4697 -600.9388 -498.8646 -334.0308 -418.8252 -490.8971 -599.8221 -520.0138
 -22.9486  -38.9243   16.1794  -27.7862   36.5483  -39.4086  -67.1051  -88.4310
 -85.9609 -134.7297 -159.6069  -40.0147 -141.4054 -108.9553 -112.4971  -74.8333
 204.7162  341.9136  237.0759  221.7378  310.3222  249.4976  302.2606  331.9651
 198.0648  225.0821  305.8521  108.1960  210.8679  264.3739  258.9245  162.3521
 -57.5268  -71.1029  -89.2489  -29.1808  -51.9062  -22.8865  -86.6508  -55.9651
 268.6915  419.6455  330.0816  251.6573  271.1543  398.3147  499.3630  434.2946
 160.4442  222.8163  159.3813  112.1083   91.2330  116.1224  215.7014  138.3558
 231.7073  410.5033  320.3985  249.2467  458.0750  279.3971  402.7326  400.0425
 

255 1 1
Variable containing:

Columns 0 to 6 
  -98.4580  -613.5593  -192.5880   726.0607   235.3293   148.7852   533.5169
  414.0319   287.0110   330.1857  2060.4971   125.9600   284.0019  1436.1360
  206.0949   121.7750   189.7576   808.5983   -56.6576    60.1902   523.2057
 -265.9784  -196.7387  -320.2847  -868.7852  -255.3296  -316.6601  -641.4170
  -68.3206   527.0132   -60.9855  -585.4313    75.7591   -74.0228  -398.1193
 -108.1389  -111.3251  -152.6685   361.9313  -148.9485  -146.7002   174.3611
  171.6540    46.0288   105.0274    24.8375   330.9738   262.2329   123.3962
  433.8966   871.0914   457.8849   205.2007   286.8817   238.5829   191.5020
  -43.7913  -420.9353   -79.8818   -13.2007    63.9370    68.8503    -4.6477
  298.9371  -202.4396   273.4389   963.8452    23.1922   267.0724   689.1721
  140.5423   169.5857   194.5396   772.2440   -37.3447    21.4661   493.6007
  293.1490   369.9958   273.2545 -1012.0083   572.4057   431.7021  -532.1278
  168.1448    74.2021   276.44

255 1 1
Variable containing:

Columns 0 to 6 
   55.9452    28.6776   413.7400    77.9937   988.9252   681.2457   928.6648
  823.2597  1055.3866   513.4304   793.6485   358.1929  2276.0249   686.4927
  463.4240   553.0206   -91.1199   435.3440  -530.8456  1027.6403  -275.2686
 -423.5439  -517.1088  -104.3162  -335.2118  -411.0370  -798.4686  -237.6600
 -416.4140  -254.8816   705.9547  -138.1076   607.6851  -470.2288   868.8085
   70.5989    46.1253  -154.2623   117.3595  -224.4799   516.9733  -191.9199
   -8.9931   -36.1213   451.9189    21.4133   943.9041  -212.1816   801.4139
  285.2687   414.3919   189.7079   310.9042   217.1074     5.7001   128.5450
  -19.8712  -196.4312  -313.0397  -106.2383   176.1036  -345.7509  -199.9810
  665.6089   573.3804  -554.3378   366.5082  -533.7722   942.2537  -732.7529
  413.6979   477.6137  -420.9854   342.8408  -696.2310   857.9987  -744.9888
 -174.9926  -187.3975   853.1619  -105.6462  1295.7811 -1154.1282  1291.6106
   85.6156   -73.0759  -831.34

255 1 1
Variable containing:

Columns 0 to 6 
  396.1710   611.1769   435.3429   725.6790   754.0438   199.4838   745.1205
  120.5427  2705.8232   875.3053  2013.8604  2337.9360  1227.1705   652.2244
 -256.8752  1439.5217   131.4101   467.1956   624.6756   728.3425   -63.8545
 -266.0807  -761.4582  -284.4030  -435.8813  -502.0706  -370.5451   -71.6034
  280.0260  -522.7799   365.8985   677.1016   711.4850  -396.2187   595.4503
 -198.7868   726.4229    -9.2581    98.5694   175.3111   332.6790  -272.3020
  504.8878  -306.2860   240.0765   198.0072   121.3861  -115.8214   811.9730
  276.8186   229.3148   151.0125    16.9267    41.7604   194.7257   215.3577
   48.6790  -479.6502  -272.1464  -768.0482  -856.7922  -139.2450  -370.3779
 -236.8884  1205.4045  -122.0471  -235.4095  -144.1418   724.9384  -619.8795
 -227.3900  1187.5631   -54.4928   -61.2810    62.0183   637.5989  -555.7264
  765.9144 -1407.0665   267.5814   121.7832   -54.1791  -568.3554  1130.2737
 -123.1815 -1221.7284  -755.77

255 1 1
Variable containing:

Columns 0 to 6 
  337.5161   465.2321   735.4440   294.7927   347.9738   400.3153   319.4188
  839.4775    98.7566  1735.5331  1682.1326    -1.4022   589.5657  1900.1670
  168.4560  -180.5479    95.3823   982.7570  -145.6494   -97.8465  1140.3545
 -190.8274  -393.8274   -38.4917  -394.4087  -426.6472    21.1714  -439.6925
  308.3678  -111.3615  1368.1936  -104.1933  -171.0936   826.7453  -156.5803
   14.4602   -16.4182  -208.1116   532.7056   -14.8247  -237.9615   610.8867
  137.6555   417.4320   725.7543  -209.3586   363.1245   615.7484  -227.2582
   47.9891   124.5896   184.7856   280.8046   162.0380   275.1824   336.8415
 -321.6635   420.1562 -1188.0458  -282.1563   456.8793  -583.6187  -307.6616
 -106.3468   265.8785 -1191.5878   635.1953   374.6920  -798.7099   774.2295
  -44.0421   -85.0256  -703.2303   708.3378    -6.0413  -440.8255   848.6038
  122.1407   569.6532  1074.8484  -697.4354   590.5095   980.4251  -800.8754
 -822.5114   294.2144 -2923.31


<class 'torch.autograd.variable.Variable'>
<class 'torch.autograd.variable.Variable'>
torch.Size([16, 8])
torch.Size([1, 8])
255 1 1
Variable containing:

Columns 0 to 6 
  196.7741   490.4744   -86.2780   129.1376   555.0151   748.8704   307.9070
    7.4148  1034.8851   767.8452  1127.0707    51.4518  2978.2017   657.2844
 -129.4886   -99.9757   749.5443   776.6872  -349.1742  1342.6334    -4.9227
 -288.4580    57.6165  -200.9436  -285.5578  -119.6301  -738.5024   -17.9004
  -59.0700  1326.6455  -389.7232  -211.6805   463.9432   137.5943   682.2621
 -130.8335  -189.3111   401.9773   396.3431  -200.2716   765.8528  -109.5163
  317.5870   673.2683  -161.6349  -123.1887   657.9589  -288.3796   377.2148
  153.5158   306.5863   278.3462   271.8118   175.0311   334.4355   183.9011
  208.7179  -919.3558    85.9551   -87.1206    -3.9411  -638.0652  -514.0826
  193.8032 -1187.9894   758.8290   642.1474  -492.7916   671.3285  -552.4409
  -55.5533  -595.9908   684.2703   617.2544  -338.1662   8

255 1 1
Variable containing:

Columns 0 to 6 
  167.7304   272.9390  -234.2400   668.8321  -151.9034   785.8852   289.8893
  166.6822   706.3422   -14.6953  2149.8845    37.5932  2502.8940   906.0951
  -83.5613  -185.3580   132.2996   856.9852   114.9514   962.7068  -335.4731
  -13.8957    97.0751  -495.5301  -532.7541  -437.2250  -597.5565   209.0738
  295.3774  1041.0712  -439.1723   120.1249  -377.6965   209.9661  1595.6891
 -154.5919  -343.1642  -158.6290   342.0199  -131.5273   415.5123  -495.1954
  302.5197   558.2788     0.6345   -83.4393    27.2599  -119.6705   782.8250
  142.5157   275.2891   219.6167   243.3379   186.7385   262.0961   413.6455
 -174.8399  -778.9952   331.2618  -531.1924   288.5158  -650.9689 -1139.8813
 -265.2490  -973.5486   742.5792   398.4292   650.3720   371.5780 -1568.1022
 -184.2586  -656.4518   284.1628   536.1373   239.7859   572.3677 -1017.6389
  512.9831   993.5938   487.1666  -692.2027   430.1934  -779.1763  1477.8828
 -411.7678 -1715.0530   858.90

Variable containing:

Columns 0 to 6 
  204.2098   593.0823    81.3521   354.7387   406.3262   477.4181   683.4144
  606.7121  2251.0156   734.3335   770.7831   802.3414  1910.8649  2386.8672
  203.7361  1114.4860   402.5667  -109.1330   -85.6452   971.6437  1046.0088
 -138.8481  -461.6863  -319.0710   146.8848   123.7225  -371.2577  -531.6763
   74.3108  -163.7204  -274.8342  1020.3188  1001.0073  -109.4152   -37.2240
   61.4444   512.5599    25.2991  -280.9323  -293.8519   453.7287   456.2823
   16.5237  -218.8597    -1.7823   504.9534   539.4387  -145.7561  -180.1248
   54.7231   256.8807   165.0961   224.4500   248.0148   209.5957   269.2815
 -179.2662  -528.5588   -12.6488  -813.9232  -813.6378  -438.5723  -589.6038
   54.6619   650.1656   566.3260  -977.8531  -940.2483   560.6938   550.2983
   92.8278   860.5999   344.1914  -691.1537  -667.1974   739.0222   738.7039
 -133.6296 -1166.6007   -31.2967   872.8571   900.1332  -874.9921 -1059.5710
 -435.2010 -1213.3469   -71.9455 -1873

Variable containing:

Columns 0 to 6 
  251.4498   485.5776   295.6325   672.3168   440.2427   320.8432   619.5449
  609.6792  1114.1392   603.0832   426.1001  1433.8698     5.8491   606.6374
  124.9603    44.6337   -15.2140  -229.0209   529.6334  -134.4649  -155.1796
  -70.2530   103.7800    76.7303   216.5640  -173.5973   -88.6404   282.1624
  258.0211  1009.4044   648.7599  1035.9957   250.9656    -8.0891  1158.3938
  -50.2138  -198.2559  -181.5802  -306.8568   173.7793   -53.7414  -356.6397
  161.3777   495.0469   360.4552   860.6257    84.1919   357.7836   826.3333
   86.3390   151.9057   122.9034   156.5267   164.3273    36.0407   200.1941
 -317.1192  -916.6213  -571.9462  -588.8260  -530.3279   222.8566  -816.3497
 -125.0553  -911.0378  -565.6049 -1105.1068    33.7732   113.4057 -1209.3497
  -73.2286  -575.3566  -391.8037  -752.8163   217.3738   -72.3321  -816.5055
  125.9182   674.2443   526.8556  1331.4484  -272.2292   588.1929  1312.6537
 -678.1979 -2069.7129 -1216.5869 -1640

Variable containing:

Columns 0 to 6 
  657.1694  -222.7955   193.3928   330.1187   579.0648   787.6154   640.1694
 1599.7029   -10.7726   317.3881   220.0840  1048.2208  3041.3936   540.9370
  108.2922   138.6476    34.9339   -68.5819    50.2428  1244.3678  -121.5579
  158.5672  -253.6438    18.4990    93.3231   167.7024  -545.3413   247.5108
 1379.8235  -506.8713   269.2483   408.6705  1052.3640  -135.0520   941.2961
 -278.9041   118.4769   -59.5971   -94.9886  -266.9944   486.2333  -241.3238
  650.0387   -62.0317   167.8344   366.5941   601.5248  -196.9841   746.2449
  174.6720    88.1259    49.6426    53.9537   161.3818   290.7985   119.6054
-1312.9797   517.6514  -231.7355  -221.3812  -947.0645  -919.1480  -617.4553
-1314.6663   755.6858  -229.7061  -407.6480 -1030.1534   602.9301 -1000.7103
 -763.2226   421.3639  -129.4597  -273.6104  -609.6007   855.2313  -640.4717
  820.4511   149.7247   249.7815   607.3224   843.3090 -1471.3204  1191.0876
-3043.4431   674.1345  -578.1814  -764

255 1 1
Variable containing:

Columns 0 to 6 
 1044.9391   271.1653   104.9754  -105.4989   743.1046   903.3455    55.2625
 3459.4978   492.1868   535.7291   396.9809  -668.5988  -696.8449   426.2468
  804.8014  -119.9810   311.8995   429.0894  -638.4160  -905.5529   382.1301
 -492.6746   252.4534   -67.6453  -293.3504   119.9013   515.6802  -248.5043
  842.7463   725.7411   -86.5887  -545.2386   297.4142   944.2100  -454.4743
  315.5906  -142.7637   128.2936   181.4345    32.4016  -112.2639   223.5101
   31.7651   387.4140    14.1285   -88.4961   865.7384  1300.7145    -1.1289
   39.1937    55.0833   103.0273   165.8742  -180.5533  -230.6387   133.4683
-1419.6182  -575.4133   -88.0755   304.9966   557.9472   179.5761   281.1489
 -371.5999  -695.4359   299.1644   878.4719  -389.9541 -1194.1215   766.4474
   73.0485  -448.1352   277.1741   550.0552  -421.2105  -936.0724   510.0326
 -706.4119   735.2377   -62.6144   -64.4052  1381.4752  2224.9299   -32.6334
-3269.5996 -1411.2081  -277.30


Variable containing:
 1247   988  1526  1546  1011  1255  1440  1486
[torch.cuda.FloatTensor of size 1x8 (GPU 0)]

<class 'torch.autograd.variable.Variable'>
<class 'torch.autograd.variable.Variable'>
torch.Size([16, 8])
torch.Size([1, 8])
255 1 1
Variable containing:

Columns 0 to 6 
  762.3750   212.7367   507.6519   975.6836   475.2646   490.3456    79.0996
 -280.4875   712.4417  1958.8002  -380.5097  1667.8589   116.4630   268.8171
 -721.7579   166.4756   369.5688  -895.9913   259.2737  -391.5500   192.5831
  370.5767  -123.4257  -345.3531   414.1003  -296.8076   191.7312  -304.3342
  889.7243    57.8071   238.2541  1051.2781   274.4935   633.7037  -367.7476
 -196.2696    67.7622   225.5972  -252.5600   114.5945  -177.2887   121.5916
  988.6514    26.3845   -97.9730  1230.5586   -26.3609   623.2776    87.6337
 -200.5684    92.3880   103.4955  -260.0681   102.8961   -28.7440   108.6309
 -108.7131  -198.0139  -617.7900   -68.2853  -562.4921  -227.7195   331.0143
-1044.8555   135.196


Variable containing:
 1045  1006  1376  1272  1416  1218  1140  1330
[torch.cuda.FloatTensor of size 1x8 (GPU 0)]

<class 'torch.autograd.variable.Variable'>
<class 'torch.autograd.variable.Variable'>
torch.Size([16, 8])
torch.Size([1, 8])
255 1 1
Variable containing:

Columns 0 to 6 
  535.1999    50.4358   649.2742   736.3322   835.9279  -261.9423   911.4847
 1913.7277   332.4466  2052.0435  -129.3724  -278.1168    98.5644  -296.2423
  453.4724   228.2730   264.8547  -894.1339  -647.0229   422.6222  -742.4559
 -376.5466   -78.2590  -383.7970   565.6333   126.1670  -302.2430   184.7751
  -46.1927  -131.2345   261.4382  1369.6544   599.3101  -525.5375   741.4717
  245.8584   108.5128   152.5696  -416.6130  -149.2502   187.5701  -192.7737
 -146.6704     3.9607   -54.1142  1059.3099   894.7202   -96.2177   991.1222
  164.7069   109.2641   100.1674  -100.3686  -192.3738   214.4080  -217.5597
 -512.6095    14.9048  -646.5325  -601.8485   147.4893   439.2295    78.7709
  390.6459   307.465

KeyboardInterrupt: 