In [1]:
import pandas as pd
import numpy as np

from tqdm import tqdm_notebook

import cv2
import os

import torch
from torch import nn
import torch.nn.functional as F
from torch.autograd import Variable
from torch.optim import Adam

import matplotlib.pyplot as plt
%matplotlib inline

In [2]:
TRAIN_DIR = './input/stage1_train/'

train_ids = os.listdir(TRAIN_DIR)
train_images = [os.path.join(TRAIN_DIR, train_id, 'images', train_id + '.png') 
                for train_id in train_ids]
train_masks = {train_id: [os.path.join(TRAIN_DIR, train_id, 'masks', img_name) 
                          for img_name in os.listdir(os.path.join(TRAIN_DIR, train_id, 'masks'))]
               for train_id in train_ids}

In [3]:
X = {train_ids[i]: cv2.imread(train_images[i]) for i in range(len(train_images))}

Y = {train_id: sum((cv2.imread(train_mask)[..., 0]
                    for train_mask in train_masks[train_id]))
     for train_id in train_ids}

In [5]:
model = nn.Sequential(nn.Conv2d(3, 16, (11, 11), padding=5),
                      nn.ReLU(),
                      nn.Conv2d(16, 16, (5, 5), padding=2),
                      nn.ReLU(),
                      nn.Conv2d(16, 1, (5, 5), padding=2),
                      nn.Sigmoid())

In [6]:
N_EPOCHS = 10
BATCH_SIZE = 16

optimizer = Adam(model.parameters())

for epoch in range(N_EPOCHS):
    np.random.shuffle(train_ids)
    i = 0
    avg_loss = 0
    optimizer.zero_grad()
    for tr_id in tqdm_notebook(train_ids):
        batch_x = np.expand_dims(np.swapaxes(X[tr_id], 0, 2), 0) / 255.0
        batch_x = Variable(torch.FloatTensor(batch_x))

        batch_y = np.expand_dims(np.swapaxes(Y[tr_id], 0, 1), 0) / 255.0
        batch_y = Variable(torch.FloatTensor(batch_y))
        
        prediction = model(batch_x)[:, 0]

        loss = F.binary_cross_entropy(prediction, batch_y)
        avg_loss = 0.9 * avg_loss + 0.1 * loss.data[0]
        loss = loss / BATCH_SIZE
        loss.backward()
        
        avg_loss = 0.9 * avg_loss + 0.1 * loss.data[0]
        
        if i % BATCH_SIZE == BATCH_SIZE - 1:
            print(avg_loss)
            optimizer.step()
            
            i = -1
            optimizer.zero_grad()
        i += 1

0.33649299177419645
0.34527859320100657
0.34303715523959216
0.34575391509940134
0.34070421815459306
0.3319864692918182
0.30075458865610166
0.3310741563929824
0.3263619389966045
0.2935727722608237
0.3631558761022554
0.35533498400739644
0.3130450306167215
0.25869480742549134
0.3015571986291396
0.30963385482259936
0.32305473692575126
0.3194843171350899
0.2830606774992771
0.24206425109204419
0.24186673388287167
0.2813719182330518
0.3190010611011181
0.26500429445333473
0.2593272992581935
0.2643933605692287
0.26160719953431394
0.28956509364015326
0.29595193388922103
0.28316461367940354
0.28267051444701985
0.29426964147311235
0.24987605421085493
0.28443574518348497
0.2653919638215132
0.31030746544786847
0.28951448877820857
0.26365884250493743
0.23561612349388297
0.23343231511060117
0.22407950603534338



0.2239817205718239
0.22067991590633332
0.22668618945453045
0.18628831774193608
0.21055650745198312
0.20121853282878618
0.20391817588510106
0.20904460717061724
0.1846415860495694
0.22018393243639442
0.17136572008141038
0.19467791734214335
0.23167073222038598
0.19014662885388894
0.1797966937755135
0.17698371351463613
0.20498611901556182
0.2825440318831344
0.1671056419573791
0.18480016967173474
0.2177941839730656
0.23388875050140578
0.12335876030460749
0.10659791679314697
0.25118475905630916
0.17122866864641076
0.18910011511828193
0.1851166116640844
0.1596106043488506
0.22067349006486922
0.23092362516479692
0.1706127926901736
0.10673041227443703
0.1558930155461929
0.18923752699335847
0.12914258496986106
0.12718194087463378
0.14092122966974238
0.12367510891949739
0.12209160945710189
0.1110777127153104



0.12232086669531811
0.10880718488614655
0.1673109840724804
0.15757967462550102
0.13759680025895402
0.17550148983831163
0.20567436622686577
0.13328043625672487
0.12195146845537885
0.1775629876933664
0.15643929743272936
0.1189176871312779
0.1865502616644357
0.11554608917806308
0.18249009559348944
0.15762527275386048
0.11949066258127522
0.14961486895638748
0.12831239053805393
0.14797115723182747
0.17245084659128215
0.13987278985061885
0.14510549714638157
0.15458822721821147
0.09808507322532281
0.1397531309162783
0.17389609022752212
0.11266920117803743
0.13551167672371306
0.10594188996602484
0.11167889940217489
0.12476686224871926
0.13213489096036304
0.09863943880655691
0.14286499750786824
0.12969341236668988
0.07821238635663338
0.12733297262445772
0.10979218396718474
0.14715739889705673
0.14481058482119744



0.10892586732959697
0.15863498913054405
0.14600537667265875
0.143059352829299
0.0956130243074286
0.14591322068778664
0.1433320278155207
0.08686804427233066
0.13432243706424296
0.10373546209780658
0.1546825808236643
0.12279690188447678
0.1372021310236845
0.14786518038050722
0.09672149600908692
0.12297885578962744
0.12946702792167641
0.12415581793230442
0.14559983528992598
0.12079182755590503
0.13849663828616052
0.16203263586083766
0.13762860964035137
0.09665050158608342
0.0657436175850814
0.10328263929295507
0.1518229873138989
0.07087786256861296
0.1799699189654416
0.10891682440427498
0.1304513822564989
0.11103109053370017
0.09465364275737732
0.13681575589492972
0.08538183527002899
0.13927071609600644
0.12342549919961376
0.12124347361776129
0.08222506112863284
0.10319956878511802
0.07124621742527494



0.14887859030061448
0.11572197021385422
0.15337792893838995
0.15240308628383153
0.12296642119862561
0.12443949623832977
0.12541295161817775
0.10951756750300548
0.09418728909604394
0.10126920604013072
0.09068321982427091
0.09262387771790893
0.1339644402074862
0.08188060345189249
0.09158818387864993
0.09865762935993524
0.11005423247141184
0.1388091835833927
0.13164077230422364
0.09841732346600535
0.08606686148467983
0.11227273916996344
0.09264404999959022
0.057337743488122164
0.08115611740133852
0.06305718862782997
0.09955257664812206
0.1095260178930984
0.16543645022165743
0.10168248598359615
0.0768276000425373
0.11446139763215274
0.08720504129374614
0.12419489946707342
0.10495183982642858
0.12281930231912948
0.10302869087773349
0.10884328175831742
0.11597690018730326
0.0886973289407946
0.06790550490668917



0.06606018206977494
0.14281892508435454
0.08422617446116715
0.0729325197362287
0.13750193940490293
0.09298655284796849
0.06411107236539576
0.06912109538670774
0.12535892515663696
0.11011522921096464
0.08703354130676728
0.06132172951160989
0.1159475490080072
0.08528703450630772
0.08652523407255497
0.10954304158547948
0.06407650281870797
0.0963323832552488
0.1075581318954375
0.06343988620553634
0.09946566287586929
0.09832417624988635
0.08919968977419847
0.14223794232597053
0.08998949300644679
0.09477481755025728
0.09873099023620277
0.10755008327666125
0.09883075825805411
0.08656500110485378
0.0993528225473879
0.06678498060560195
0.05193108486979946
0.08467552760738979
0.048565587612762466
0.08690506227489754



0.12168045067618971
0.05488528938897094
0.05725915644727407
0.10275906043419578
0.07940537699030763
0.09701717903761636
0.11365174898558864
0.06575575662216938
0.10286580832063862
0.061846255498587126
0.08413733380781349
0.07622993840278362
0.0744344943241229
0.07815936012441348
0.10299329796488203
0.07620654614246412
0.08374605416264967
0.10281077610016753
0.09474020236917079
0.10251046885086812
0.09433065651618695
0.057929758422814594
0.0831585178415865
0.13890221625061328
0.06821383107064222
0.08033181912227129
0.10404563613747202
0.12642586708595438
0.1393179062195444
0.08855529487238563
0.1126990150879555
0.08926597397471575
0.08826107377307721
0.09717820540461278
0.10051831136707481
0.0750862261244636
0.1097324228691528
0.09009934676896798
0.10978165648392589
0.06401897087486301
0.1009970707001931



0.08357901180167143
0.07247114435564528
0.06224076359803361
0.05249245296855942
0.09181025552175286
0.14231607192297707
0.06832258395539313
0.09626080721574384
0.06045929517947184
0.0632990077806847
0.03945586351194853
0.1344468213695893
0.06335599710945691
0.12658980076408818
0.08622497257131652
0.061644698594185046
0.07122563153322882
0.09866244066450662
0.1499226360148527
0.15666638482873424
0.06943832838366182
0.13670923961429693
0.12061820994710365
0.11244275321272533
0.04883570134650257
0.11459870182999989
0.11147077128360217
0.08515346504903863
0.0978098696124916
0.10614554092918516
0.07199640277709933
0.11004096142866375
0.08245893850991412
0.08720526963452109
0.07910026214317102
0.14232312372882203
0.09372253808590733
0.07964076970637265
0.07745066210768244
0.053511974908183685
0.054798997483502025



0.124147644106556
0.09490893016265982
0.0569077227295865
0.059955430872683464
0.05236616218751645
0.07712061244944161
0.07117941492797154
0.08080312470274982
0.13080482406559452
0.06495366033621439
0.07619052146570542
0.10659134979429596
0.12878143982925205
0.0896860576859776
0.10887476797539641
0.10830600404920468
0.08584752091355405
0.03304178915657924
0.0833973779335188
0.10603674596674299
0.07825256417945045
0.09004693602639666
0.04506585035806972
0.11802825457572998
0.09162726583547788
0.04298474752201085
0.039783895897110046
0.0664280303028081
0.07836930883108019
0.08530206743678918
0.09070580799237825
0.09587050825363816
0.07267602134726547
0.059924220696327156
0.065719839652996
0.059376526179193125
0.08432578231462036
0.03865918537428516
0.06198820608498417
0.10828087097260156
0.07373178434578287



0.06306534515621899
0.05477894820952202
0.050857466469354067
0.07442223762987628
0.11105570882697981
0.06577343471359154
0.08122752555134599
0.08821137311867273
0.10543841367285706
0.091960672706543
0.047429106814413216
0.07811065451194027
0.09850941549906538
0.09646538460520299
0.09856143875425144
0.08577306027121938
0.06628513680742708
0.09026306651431729
0.11220031734741322
0.09071184670558198
0.07873770100563494
0.13143548265967694
0.0908983478615024
0.07847448562778248
0.08687883703689814
0.07311095375540148
0.07458760050977511
0.08503344209522391
0.0769939827830356
0.08072471882769874
0.06774100051822309
0.05084722846999832
0.07646355150878124
0.055357759686256516
0.06535085689900354
0.1211896209388091
0.10183022586062143
0.055688487516671756
0.11190753197337586
0.06358323909624185
0.09133877292842318



In [7]:
TEST_DIR = './input/stage1_test/'
test_ids = os.listdir(TEST_DIR)
test_images = [os.path.join(TEST_DIR, test_id, 'images', test_id + '.png') 
                for test_id in test_ids]

X_test = {test_ids[i]: cv2.imread(test_images[i]) for i in range(len(test_images))}

In [8]:
test_id = test_ids[0]

batch_x = np.expand_dims(np.swapaxes(X_test[test_ids[0]], 0, 2), 0) / 255.0
batch_x = Variable(torch.FloatTensor(batch_x))

In [9]:
from skimage.morphology import label

def rle_encoding(x):
    dots = np.where(x.flatten() == 1)[0]
    run_lengths = []
    prev = -2
    for b in dots:
        if (b>prev+1): run_lengths.extend((b + 1, 0))
        run_lengths[-1] += 1
        prev = b
    return run_lengths

def prob_to_rles(x, cutoff=0.5):
    lab_img = label(x > cutoff)
    for i in range(1, lab_img.max() + 1):
        yield rle_encoding(lab_img == i)

In [10]:
image_ids = []
rles = []

for test_id in tqdm_notebook(test_ids):
    batch_x = np.expand_dims(np.swapaxes(X_test[test_id], 0, 2), 0) / 255.0
    batch_x = Variable(torch.FloatTensor(batch_x))
    
    prediction = model(batch_x)[0, 0].data.numpy()
    
    for rle in prob_to_rles(prediction):
        image_ids.append(test_id)
        rles.append(rle)




In [11]:
submission = pd.DataFrame(data={'ImageId': image_ids,
                                'EncodedPixels': [' '.join(map(str, x)) for x in rles]})

In [12]:
submission.to_csv('./0.csv', index=None)

In [13]:
from IPython.display import FileLink

In [14]:
FileLink('./0.csv')