In [2]:
import torch
import torch.nn as nn
import torchvision
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import os
import numpy as np
import torchvision.transforms as transforms
from PIL import Image

In [3]:
class Unet_dataset(Dataset):
    def __init__(self, img_path, label_path, transform=None):
        super().__init__()
        self.dataset = os.listdir(img_path)
        self.labels = os.listdir(label_path)
        self.transform = transform
    
    def __len__(self):
        return len(self.dataset)
    
    def __getitem__(self, idx):
        image = Image.open(os.path.join('./unet_train/bldg/src/', self.dataset[idx]))
        label = Image.open(os.path.join('./unet_train/bldg/label/', self.labels[idx]))
        label = torch.from_numpy(np.array(label)).float().unsqueeze(0)
        if self.transform is not None:
            image = self.transform(image)
        
        return image, label

transform = transforms.Compose([
    transforms.ToTensor(),
])
bldg_dataset = Unet_dataset('./unet_train/bldg/src/', './unet_train/bldg/label/', transform=transform)
bldg_dataloader = DataLoader(bldg_dataset, batch_size=8, num_workers=0, shuffle=True)

In [4]:
class BasicConv2d(nn.Module):
    def __init__(self, inp, oup):
        super(BasicConv2d, self).__init__()
        self.conv1 = nn.Conv2d(inp, oup, 3, stride=1, padding=1)
        self.conv2 = nn.Conv2d(oup, oup, 3, stride=1, padding=1)
        self.bn = nn.BatchNorm2d(oup)
    
    def forward(self, x):
        x = self.conv1(x)
        x = self.bn(x)
        x = F.elu(x, inplace=True)
        x = self.conv2(x)
        x = self.bn(x)
        return F.elu(x, inplace=True)

    
class Unet(nn.Module):
    def __init__(self, num_classes=1):
        super(Unet, self).__init__()
        self.conv1 = BasicConv2d(3, 32)
        self.conv2 = BasicConv2d(32, 64)
        self.conv3 = BasicConv2d(64, 128)
        self.conv4 = BasicConv2d(128, 256)
        self.conv5 = BasicConv2d(256, 512)
        
        self.conv6 = BasicConv2d(768, 256)
        self.conv7 = BasicConv2d(384, 128)
        self.conv8 = BasicConv2d(192, 64)
        self.conv9 = BasicConv2d(96, 32)
        
        self.MaxPool = nn.MaxPool2d(2, 2)
        
        self.conv10 = nn.Conv2d(32, num_classes, 1)
    
    def forward(self, x):
        dec1 = self.conv1(x)
        max_dec1 = self.MaxPool(dec1)
        dec2 = self.conv2(max_dec1)
        max_dec2 = self.MaxPool(dec2)
        dec3 = self.conv3(max_dec2)
        max_dec3 = self.MaxPool(dec3)
        dec4 = self.conv4(max_dec3)
        max_dec4 = self.MaxPool(dec4)
        
        center = self.conv5(dec4)
        up6 = torch.cat([
            dec4,  F.upsample_bilinear(center, dec4.size()[2:])], 1)
        dec6 = self.conv6(up6)
        up7 = torch.cat([
            dec3, F.upsample_bilinear(dec6, dec3.size()[2:])], 1)
        dec7 = self.conv7(up7)
        up8 = torch.cat([
            dec2, F.upsample_bilinear(dec7, dec2.size()[2:])], 1)
        dec8 = self.conv8(up8)
        up9 = torch.cat([
            dec1, F.upsample_bilinear(dec8, dec1.size()[2:])], 1)
        dec9 = self.conv9(up9)
        out = self.conv10(dec9)
        return nn.Sigmoid()(out)

In [4]:
# train
from torch.optim import lr_scheduler
from torch.autograd import Variable

model = Unet().cuda()
optimizer = torch.optim.Adam(model.parameters(), 1e-3, (0.9, 0.999), eps=1e-08, weight_decay=1e-4)
criterion = torch.nn.BCELoss()
num_epochs = 20
lambda1 = lambda epoch: pow((1-((epoch-1)/num_epochs)),0.9)
scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda1)


for epoch in range(num_epochs):
    print("---training -epoch", epoch, "---")
    scheduler.step(epoch)
    epoch_loss = []
    acc = []
    usedLr = 0
    
    for param_group in optimizer.param_groups:
        print("LEARNING RATE: ", param_group['lr'])
        usedLr = float(param_group['lr'])
    
    model.train()
    for i, (images, labels) in enumerate(bldg_dataloader):
        l = len(bldg_dataloader)
        images = images.to('cuda').float()
        targets = labels.to('cuda').float()
        
        outputs = model(images)
        loss = criterion(outputs, targets)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        epoch_loss.append(loss.item())

        total = len(outputs.view(-1))
        targets = targets.squeeze()
        outputs = outputs.squeeze() > 0.5
        outputs = outputs.float()
        correct = len(outputs[targets==outputs])
        acc_ = correct / total
        
        acc.append(acc_)
        
        if i % 49 == 0:
            average = sum(epoch_loss) / len(epoch_loss)
            accuracy = sum(acc) / len(acc)
            print("loss: {} epoch: {}, acc: {}, {}/{}".format(average, epoch, accuracy, i, l))
    
    # save model
    if epoch % 2 == 0:
        torch.save(model, './model_{}.pth'.format(epoch))

---training -epoch 0 ---
LEARNING RATE:  0.0010448895099824027




loss: 0.7344312071800232 epoch: 0, acc: 0.4294583067602041, 0/6250
loss: 0.6462302422523498 epoch: 0, acc: 0.6548760363520408, 49/6250
loss: 0.6016978811133992 epoch: 0, acc: 0.6964160901102867, 98/6250
loss: 0.5779351022195172 epoch: 0, acc: 0.7127094384523238, 147/6250
loss: 0.5606895848579213 epoch: 0, acc: 0.7225204336572829, 196/6250
loss: 0.5445439551419359 epoch: 0, acc: 0.7316498364780364, 245/6250
loss: 0.5340236428430525 epoch: 0, acc: 0.7374241314640263, 294/6250
loss: 0.5302544780248819 epoch: 0, acc: 0.7391965456013434, 343/6250
loss: 0.522740699104377 epoch: 0, acc: 0.7437787182501816, 392/6250
loss: 0.5199420622031613 epoch: 0, acc: 0.7456528810442445, 441/6250
loss: 0.5174453810858872 epoch: 0, acc: 0.7468168665442247, 490/6250
loss: 0.5135858945272587 epoch: 0, acc: 0.7491275155895691, 539/6250
loss: 0.510937422868149 epoch: 0, acc: 0.7496106743312775, 588/6250
loss: 0.5079553439325674 epoch: 0, acc: 0.7508119225967227, 637/6250
loss: 0.5061397598473439 epoch: 0, acc: 

loss: 0.463563850173669 epoch: 0, acc: 0.7730996644534734, 5782/6250
loss: 0.4633402817426518 epoch: 0, acc: 0.7732097796278362, 5831/6250
loss: 0.46343648011810906 epoch: 0, acc: 0.7731371645604445, 5880/6250
loss: 0.46335949621102057 epoch: 0, acc: 0.7730794584512766, 5929/6250
loss: 0.4633703155872674 epoch: 0, acc: 0.773056222928208, 5978/6250
loss: 0.46308138104577135 epoch: 0, acc: 0.7732321908368939, 6027/6250
loss: 0.4631103902996342 epoch: 0, acc: 0.7731693095889107, 6076/6250
loss: 0.46299445219739294 epoch: 0, acc: 0.7732459558221857, 6125/6250
loss: 0.46305017459971703 epoch: 0, acc: 0.7732143986769813, 6174/6250
loss: 0.4629560588071348 epoch: 0, acc: 0.7732673838346094, 6223/6250


  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "


---training -epoch 1 ---
LEARNING RATE:  0.001
loss: 0.3819466531276703 epoch: 1, acc: 0.8198939732142857, 0/6250
loss: 0.4291510689258575 epoch: 1, acc: 0.7855763711734697, 49/6250
loss: 0.42089326544241473 epoch: 1, acc: 0.791187041814832, 98/6250
loss: 0.42941557595858704 epoch: 1, acc: 0.7882863132584114, 147/6250
loss: 0.4334263520192374 epoch: 1, acc: 0.7874424158940225, 196/6250
loss: 0.44454529457460573 epoch: 1, acc: 0.7807513892561597, 245/6250
loss: 0.4442360648159253 epoch: 1, acc: 0.7809624021748532, 294/6250
loss: 0.4394691471703524 epoch: 1, acc: 0.7834047408934506, 343/6250
loss: 0.4349480712307622 epoch: 1, acc: 0.787805778621086, 392/6250
loss: 0.4387032965122305 epoch: 1, acc: 0.7852807664532162, 441/6250
loss: 0.43990013103499676 epoch: 1, acc: 0.7848360126732724, 490/6250
loss: 0.4398967282363662 epoch: 1, acc: 0.7859011095757751, 539/6250
loss: 0.4379232124582009 epoch: 1, acc: 0.7868485105603152, 588/6250
loss: 0.4388748506393552 epoch: 1, acc: 0.786634682282004,

loss: 0.43548499302172605 epoch: 1, acc: 0.7869549325891526, 5733/6250
loss: 0.43556114413603025 epoch: 1, acc: 0.7868793512063201, 5782/6250
loss: 0.4356197054192567 epoch: 1, acc: 0.7868731027687893, 5831/6250
loss: 0.4356001068303788 epoch: 1, acc: 0.7868614549562284, 5880/6250
loss: 0.4356933093266986 epoch: 1, acc: 0.7868558912089819, 5929/6250
loss: 0.4354904037536556 epoch: 1, acc: 0.7869740310394423, 5978/6250
loss: 0.4355602824730476 epoch: 1, acc: 0.786895421745185, 6027/6250
loss: 0.4356651959454364 epoch: 1, acc: 0.7868547422632974, 6076/6250
loss: 0.43552920764961 epoch: 1, acc: 0.7868786998723036, 6125/6250
loss: 0.4354393377936321 epoch: 1, acc: 0.7869305034134104, 6174/6250
loss: 0.4351601599018954 epoch: 1, acc: 0.7871066404098196, 6223/6250
---training -epoch 2 ---
LEARNING RATE:  0.0009548853816214997
loss: 0.42339926958084106 epoch: 2, acc: 0.7332190688775511, 0/6250
loss: 0.43169010639190675 epoch: 2, acc: 0.7870461973852041, 49/6250
loss: 0.42457626218145544 epoch

loss: 0.4214471314249415 epoch: 2, acc: 0.7938817345619062, 5194/6250
loss: 0.4214640297308212 epoch: 2, acc: 0.7939151973003155, 5243/6250
loss: 0.42127481477372747 epoch: 2, acc: 0.794046923467219, 5292/6250
loss: 0.4213621388048279 epoch: 2, acc: 0.7940210063720593, 5341/6250
loss: 0.42107910617696753 epoch: 2, acc: 0.794196716557902, 5390/6250
loss: 0.42101548057706917 epoch: 2, acc: 0.7941690799855103, 5439/6250
loss: 0.4209627842241829 epoch: 2, acc: 0.7941665341641082, 5488/6250
loss: 0.42082177368139007 epoch: 2, acc: 0.794188666799332, 5537/6250
loss: 0.42090337854315474 epoch: 2, acc: 0.794119435581028, 5586/6250
loss: 0.420994490051921 epoch: 2, acc: 0.7940529623164369, 5635/6250
loss: 0.420979379590604 epoch: 2, acc: 0.7941090562626769, 5684/6250
loss: 0.4209475390792392 epoch: 2, acc: 0.7941163150356718, 5733/6250
loss: 0.4211553479769476 epoch: 2, acc: 0.7940154770718053, 5782/6250
loss: 0.4210167633847363 epoch: 2, acc: 0.7941799990918113, 5831/6250
loss: 0.4210010366595

loss: 0.4105563077601985 epoch: 3, acc: 0.8017098199218503, 4655/6250
loss: 0.41044552119391375 epoch: 3, acc: 0.8018136308970347, 4704/6250
loss: 0.4104980998270944 epoch: 3, acc: 0.8017421063217198, 4753/6250
loss: 0.41060735273802995 epoch: 3, acc: 0.8017346917198356, 4802/6250
loss: 0.4105643266527887 epoch: 3, acc: 0.8017859595066285, 4851/6250
loss: 0.4105536975567011 epoch: 3, acc: 0.8017821219521553, 4900/6250
loss: 0.4104417340442388 epoch: 3, acc: 0.8018711952110416, 4949/6250
loss: 0.41052568603883055 epoch: 3, acc: 0.801799157264491, 4998/6250
loss: 0.4105821323777341 epoch: 3, acc: 0.8017163795010523, 5047/6250
loss: 0.4107826000478938 epoch: 3, acc: 0.8016286438657668, 5096/6250
loss: 0.4109333906969902 epoch: 3, acc: 0.801517465362192, 5145/6250
loss: 0.4110290027817817 epoch: 3, acc: 0.8014779161183266, 5194/6250
loss: 0.4112764237964317 epoch: 3, acc: 0.801329014850186, 5243/6250
loss: 0.4111628354288287 epoch: 3, acc: 0.8013216037643119, 5292/6250
loss: 0.411065548334

loss: 0.40154246266733135 epoch: 4, acc: 0.8075394068792342, 4116/6250
loss: 0.4017744253689546 epoch: 4, acc: 0.8075417162409242, 4165/6250
loss: 0.40154770962001307 epoch: 4, acc: 0.8077349792786318, 4214/6250
loss: 0.4012364885240346 epoch: 4, acc: 0.8079085423250107, 4263/6250
loss: 0.4007425453058672 epoch: 4, acc: 0.808227272784402, 4312/6250
loss: 0.4010161464433504 epoch: 4, acc: 0.8080776094557045, 4361/6250
loss: 0.4013429649412375 epoch: 4, acc: 0.8078711654088904, 4410/6250
loss: 0.4010749114201208 epoch: 4, acc: 0.8080524759125889, 4459/6250
loss: 0.40154515302321364 epoch: 4, acc: 0.8077826816958475, 4508/6250
loss: 0.40147826664015723 epoch: 4, acc: 0.8077673609517563, 4557/6250
loss: 0.4014151991513382 epoch: 4, acc: 0.8077261743892653, 4606/6250
loss: 0.40162669349448993 epoch: 4, acc: 0.8075234174845423, 4655/6250
loss: 0.40157333240308873 epoch: 4, acc: 0.807523508747613, 4704/6250
loss: 0.40147448371551997 epoch: 4, acc: 0.8075931387900779, 4753/6250
loss: 0.4013299

loss: 0.3962084959446117 epoch: 5, acc: 0.8116622828533341, 3528/6250
loss: 0.39645552423211566 epoch: 5, acc: 0.8114939326371627, 3577/6250
loss: 0.3963725846604864 epoch: 5, acc: 0.8115396940349247, 3626/6250
loss: 0.396422599471692 epoch: 5, acc: 0.8115298892186781, 3675/6250
loss: 0.39620377352973757 epoch: 5, acc: 0.8115498840056169, 3724/6250
loss: 0.3962301412231087 epoch: 5, acc: 0.8115319026681634, 3773/6250
loss: 0.3962798167859077 epoch: 5, acc: 0.8115098285893466, 3822/6250
loss: 0.3963814904328343 epoch: 5, acc: 0.8114665509637498, 3871/6250
loss: 0.39626200166900255 epoch: 5, acc: 0.8115261517408247, 3920/6250
loss: 0.39632414858647497 epoch: 5, acc: 0.8114676891498117, 3969/6250
loss: 0.3962258799221718 epoch: 5, acc: 0.8115027588581171, 4018/6250
loss: 0.39608077447029927 epoch: 5, acc: 0.811572622774914, 4067/6250
loss: 0.39615975368115575 epoch: 5, acc: 0.8114932982658408, 4116/6250
loss: 0.39580417754226044 epoch: 5, acc: 0.8116848544300267, 4165/6250
loss: 0.3954439

loss: 0.39388238619203153 epoch: 6, acc: 0.8131749452629081, 2989/6250
loss: 0.39401533837360003 epoch: 6, acc: 0.8131074176647455, 3038/6250
loss: 0.3940156168145142 epoch: 6, acc: 0.8131184996945466, 3087/6250
loss: 0.3938735829470908 epoch: 6, acc: 0.8132563272086206, 3136/6250
loss: 0.39371538582700494 epoch: 6, acc: 0.8133999239187786, 3185/6250
loss: 0.39332711546000965 epoch: 6, acc: 0.8136537101920555, 3234/6250
loss: 0.3935388338941532 epoch: 6, acc: 0.8136151611127848, 3283/6250
loss: 0.3933269825115456 epoch: 6, acc: 0.813699351639564, 3332/6250
loss: 0.3934370923692626 epoch: 6, acc: 0.8136496875763718, 3381/6250
loss: 0.39336082854377946 epoch: 6, acc: 0.8136320019459798, 3430/6250
loss: 0.3934146657928653 epoch: 6, acc: 0.8136423667998988, 3479/6250
loss: 0.3933579660139262 epoch: 6, acc: 0.8135741560633184, 3528/6250
loss: 0.39330042726237924 epoch: 6, acc: 0.8136933455388521, 3577/6250
loss: 0.3931607724813709 epoch: 6, acc: 0.8138138066110961, 3626/6250
loss: 0.3931565

loss: 0.38425494979737196 epoch: 7, acc: 0.8184755836526028, 2450/6250
loss: 0.3843778041303158 epoch: 7, acc: 0.8183061662946445, 2499/6250
loss: 0.3841686662024917 epoch: 7, acc: 0.8184700915937944, 2548/6250
loss: 0.3847285746220353 epoch: 7, acc: 0.8180028163834834, 2597/6250
loss: 0.3847102592261188 epoch: 7, acc: 0.8180051101121419, 2646/6250
loss: 0.38486773553318193 epoch: 7, acc: 0.8179420531716503, 2695/6250
loss: 0.38487301656674383 epoch: 7, acc: 0.8179741335998214, 2744/6250
loss: 0.3847877158541379 epoch: 7, acc: 0.8179862966823505, 2793/6250
loss: 0.3847857370914978 epoch: 7, acc: 0.8180058427889936, 2842/6250
loss: 0.3849141448993818 epoch: 7, acc: 0.8179835471533017, 2891/6250
loss: 0.3852400049055486 epoch: 7, acc: 0.8178391592366716, 2940/6250
loss: 0.3853236146595167 epoch: 7, acc: 0.8176916984655479, 2989/6250
loss: 0.3853220029091435 epoch: 7, acc: 0.8177814487869421, 3038/6250
loss: 0.3856759570886434 epoch: 7, acc: 0.8176425555877495, 3087/6250
loss: 0.385434090

loss: 0.3749893955559037 epoch: 8, acc: 0.8239512658121242, 1911/6250
loss: 0.37606412249064947 epoch: 8, acc: 0.8234900757695996, 1960/6250
loss: 0.3764437321554962 epoch: 8, acc: 0.823212148957065, 2009/6250
loss: 0.37639871326507207 epoch: 8, acc: 0.8231334299671186, 2058/6250
loss: 0.3761681535650589 epoch: 8, acc: 0.8230371313091774, 2107/6250
loss: 0.3761443633363376 epoch: 8, acc: 0.8227821887078033, 2156/6250
loss: 0.3765136172567728 epoch: 8, acc: 0.8224651538461326, 2205/6250
loss: 0.37715719351483024 epoch: 8, acc: 0.8219864303814661, 2254/6250
loss: 0.3785720881876639 epoch: 8, acc: 0.8214926860229781, 2303/6250
loss: 0.37841959899276356 epoch: 8, acc: 0.8215374275241348, 2352/6250
loss: 0.3781660160950876 epoch: 8, acc: 0.821789548914384, 2401/6250
loss: 0.3783850422967263 epoch: 8, acc: 0.8216401237858986, 2450/6250
loss: 0.378854440754652 epoch: 8, acc: 0.8215677295918378, 2499/6250
loss: 0.37869150828454673 epoch: 8, acc: 0.8215302847980607, 2548/6250
loss: 0.3785216220

loss: 0.37211307329759924 epoch: 9, acc: 0.8251116013366386, 1372/6250
loss: 0.37159926892793965 epoch: 9, acc: 0.8251866019515673, 1421/6250
loss: 0.3716716289945722 epoch: 9, acc: 0.8254484188464232, 1470/6250
loss: 0.37211252045082416 epoch: 9, acc: 0.8249936014701947, 1519/6250
loss: 0.37237757966729623 epoch: 9, acc: 0.8251141260011431, 1568/6250
loss: 0.37299753982477046 epoch: 9, acc: 0.8249339888487478, 1617/6250
loss: 0.3725651585741106 epoch: 9, acc: 0.8251070269842735, 1666/6250
loss: 0.37314385296407837 epoch: 9, acc: 0.8247727962606607, 1715/6250
loss: 0.37315691381101906 epoch: 9, acc: 0.8245482304554274, 1764/6250
loss: 0.3723846279092897 epoch: 9, acc: 0.8249379658794007, 1813/6250
loss: 0.37324490020240564 epoch: 9, acc: 0.8244920445785017, 1862/6250
loss: 0.37360487145734134 epoch: 9, acc: 0.8242168998180813, 1911/6250
loss: 0.37341496943363667 epoch: 9, acc: 0.8243536550424622, 1960/6250
loss: 0.37383121728007473 epoch: 9, acc: 0.8242122951092762, 2009/6250
loss: 0.3

loss: 0.365870458779821 epoch: 10, acc: 0.8284946288427801, 784/6250
loss: 0.365958340734029 epoch: 10, acc: 0.828284904761141, 833/6250
loss: 0.36462259819121484 epoch: 10, acc: 0.8289378917354748, 882/6250
loss: 0.3660416170283193 epoch: 10, acc: 0.8285651566571511, 931/6250
loss: 0.3653074768462069 epoch: 10, acc: 0.8288318696171133, 980/6250
loss: 0.3656457779621615 epoch: 10, acc: 0.8284095614504906, 1029/6250
loss: 0.36616968024484087 epoch: 10, acc: 0.8285338019808823, 1078/6250
loss: 0.3670804332545463 epoch: 10, acc: 0.828274553359409, 1127/6250
loss: 0.36695436819337035 epoch: 10, acc: 0.8280167489282244, 1176/6250
loss: 0.3675965806666828 epoch: 10, acc: 0.8278312623415491, 1225/6250
loss: 0.36716045511703865 epoch: 10, acc: 0.8280908457132844, 1274/6250
loss: 0.3675268421067752 epoch: 10, acc: 0.8279517879541347, 1323/6250
loss: 0.3674680778494909 epoch: 10, acc: 0.8279468869724225, 1372/6250
loss: 0.36800737466280636 epoch: 10, acc: 0.8275151144483359, 1421/6250
loss: 0.36

loss: 0.3525521884093413 epoch: 11, acc: 0.8374294106367213, 147/6250
loss: 0.35672057219568243 epoch: 11, acc: 0.8333223398813842, 196/6250
loss: 0.3572405069703009 epoch: 11, acc: 0.8321838024462835, 245/6250
loss: 0.3534389480695886 epoch: 11, acc: 0.8338052766127642, 294/6250
loss: 0.35448650552263095 epoch: 11, acc: 0.8336787357339377, 343/6250
loss: 0.3506204178011751 epoch: 11, acc: 0.835983724059758, 392/6250
loss: 0.35519017063384684 epoch: 11, acc: 0.8336657905349643, 441/6250
loss: 0.3580107092250389 epoch: 11, acc: 0.8315175795050711, 490/6250
loss: 0.3578774496085114 epoch: 11, acc: 0.832216357680225, 539/6250
loss: 0.36074498596843113 epoch: 11, acc: 0.8311546751022142, 588/6250
loss: 0.3624441276616811 epoch: 11, acc: 0.8297293370595696, 637/6250
loss: 0.3619873545402349 epoch: 11, acc: 0.8303582887508169, 686/6250
loss: 0.3626690236244189 epoch: 11, acc: 0.8296788917142098, 735/6250
loss: 0.3626847719310955 epoch: 11, acc: 0.8296112334264923, 784/6250
loss: 0.3637507084

loss: 0.36271229382241543 epoch: 11, acc: 0.82924952076783, 5831/6250
loss: 0.36271368726502673 epoch: 11, acc: 0.829255468972308, 5880/6250
loss: 0.3628271790871532 epoch: 11, acc: 0.8292223649314271, 5929/6250
loss: 0.3627280277798339 epoch: 11, acc: 0.8293585905260242, 5978/6250
loss: 0.36295445501685697 epoch: 11, acc: 0.8292257454820738, 6027/6250
loss: 0.3627510629723059 epoch: 11, acc: 0.8293332596789327, 6076/6250
loss: 0.3627614433853758 epoch: 11, acc: 0.8292982859654933, 6125/6250
loss: 0.3629158763533179 epoch: 11, acc: 0.829242782006527, 6174/6250
loss: 0.36279862904609933 epoch: 11, acc: 0.8292943062784038, 6223/6250
---training -epoch 12 ---
LEARNING RATE:  0.00048740643917899543
loss: 0.22594307363033295 epoch: 12, acc: 0.9095583545918368, 0/6250
loss: 0.36294845163822176 epoch: 12, acc: 0.8278402024872448, 49/6250
loss: 0.3803116361601184 epoch: 12, acc: 0.820199866168058, 98/6250
loss: 0.3656493815014491 epoch: 12, acc: 0.8284083943438015, 147/6250
loss: 0.36691675944

loss: 0.35956780317027026 epoch: 12, acc: 0.8322102426599858, 5194/6250
loss: 0.3595838516689183 epoch: 12, acc: 0.8321979211360613, 5243/6250
loss: 0.3595692839690965 epoch: 12, acc: 0.8321915510341452, 5292/6250
loss: 0.35943135307723106 epoch: 12, acc: 0.8323297985102188, 5341/6250
loss: 0.35930263118390315 epoch: 12, acc: 0.8323383308265582, 5390/6250
loss: 0.35931707733338153 epoch: 12, acc: 0.8323943762075156, 5439/6250
loss: 0.35927538223677286 epoch: 12, acc: 0.832391975146375, 5488/6250
loss: 0.3591940396594179 epoch: 12, acc: 0.8324436229326524, 5537/6250
loss: 0.3592542139438241 epoch: 12, acc: 0.8324000841541105, 5586/6250
loss: 0.3591843516472661 epoch: 12, acc: 0.832408126022379, 5635/6250
loss: 0.3592586723761592 epoch: 12, acc: 0.8323676262020329, 5684/6250
loss: 0.35907643348534724 epoch: 12, acc: 0.8323436001264408, 5733/6250
loss: 0.35886484724277035 epoch: 12, acc: 0.8324469508347411, 5782/6250
loss: 0.35885986231078376 epoch: 12, acc: 0.8324068944667329, 5831/6250


loss: 0.3533449833276843 epoch: 13, acc: 0.8352570255902376, 4557/6250
loss: 0.3533112603046254 epoch: 13, acc: 0.8352704341365231, 4606/6250
loss: 0.3534201952001509 epoch: 13, acc: 0.8352543527025416, 4655/6250
loss: 0.35372298985529654 epoch: 13, acc: 0.8350120752422245, 4704/6250
loss: 0.35368905083398183 epoch: 13, acc: 0.8350222781647885, 4753/6250
loss: 0.3537899345274322 epoch: 13, acc: 0.8348954735753209, 4802/6250
loss: 0.35378029097796027 epoch: 13, acc: 0.8348836354071582, 4851/6250
loss: 0.3538485661611171 epoch: 13, acc: 0.834854117017585, 4900/6250
loss: 0.3541059119246825 epoch: 13, acc: 0.834698642596242, 4949/6250
loss: 0.3541989189463273 epoch: 13, acc: 0.8346336638182219, 4998/6250
loss: 0.3543393661337425 epoch: 13, acc: 0.8345894183851593, 5047/6250
loss: 0.3543863948379775 epoch: 13, acc: 0.8346200010454062, 5096/6250
loss: 0.3541316422831095 epoch: 13, acc: 0.8347404214506541, 5145/6250
loss: 0.35409460998103526 epoch: 13, acc: 0.8347870126605724, 5194/6250
loss

loss: 0.3524086869791344 epoch: 14, acc: 0.8359160198844287, 3920/6250
loss: 0.3521886982023716 epoch: 14, acc: 0.836025929283082, 3969/6250
loss: 0.35226191298326587 epoch: 14, acc: 0.836110429481807, 4018/6250
loss: 0.35242208814103754 epoch: 14, acc: 0.8361081848334398, 4067/6250
loss: 0.35233664864261144 epoch: 14, acc: 0.8362106579883768, 4116/6250
loss: 0.3525096671731782 epoch: 14, acc: 0.8361273382503979, 4165/6250
loss: 0.35252661569616267 epoch: 14, acc: 0.8361037333638995, 4214/6250
loss: 0.3524062647772551 epoch: 14, acc: 0.836067735754013, 4263/6250
loss: 0.3520070071594084 epoch: 14, acc: 0.8362369328962768, 4312/6250
loss: 0.3521833491982642 epoch: 14, acc: 0.8362514202186228, 4361/6250
loss: 0.3522331814387844 epoch: 14, acc: 0.8362393325844444, 4410/6250
loss: 0.352329989940329 epoch: 14, acc: 0.8362121408873967, 4459/6250
loss: 0.35232209360210875 epoch: 14, acc: 0.8361935293599329, 4508/6250
loss: 0.352409340892381 epoch: 14, acc: 0.8362288526746556, 4557/6250
loss: 

loss: 0.34730891718435014 epoch: 15, acc: 0.83883881599134, 3283/6250
loss: 0.34738395819113554 epoch: 15, acc: 0.8387589396694769, 3332/6250
loss: 0.34748763312370307 epoch: 15, acc: 0.8387889678597447, 3381/6250
loss: 0.34773203359031357 epoch: 15, acc: 0.8386697972074465, 3430/6250
loss: 0.3475887029570923 epoch: 15, acc: 0.8387202048788261, 3479/6250
loss: 0.34772172998600015 epoch: 15, acc: 0.8386499414382143, 3528/6250
loss: 0.34789641407230437 epoch: 15, acc: 0.8383647943810173, 3577/6250
loss: 0.34821046120146926 epoch: 15, acc: 0.8381830255817689, 3626/6250
loss: 0.3483125711624005 epoch: 15, acc: 0.8382006761661409, 3675/6250
loss: 0.3483561625236633 epoch: 15, acc: 0.8381830720962878, 3724/6250
loss: 0.3479690012469232 epoch: 15, acc: 0.8384179923024949, 3773/6250
loss: 0.3482718855481699 epoch: 15, acc: 0.8382427389463007, 3822/6250
loss: 0.3485186051073184 epoch: 15, acc: 0.8380476291635259, 3871/6250
loss: 0.34864419784413037 epoch: 15, acc: 0.8379245530102815, 3920/6250


loss: 0.3441707319500854 epoch: 16, acc: 0.8398377002712448, 2646/6250
loss: 0.34384281026304475 epoch: 16, acc: 0.8401748064912394, 2695/6250
loss: 0.3438088903227356 epoch: 16, acc: 0.8402710691517039, 2744/6250
loss: 0.3437411941897255 epoch: 16, acc: 0.8401780278872916, 2793/6250
loss: 0.3440616526734313 epoch: 16, acc: 0.8400351900949876, 2842/6250
loss: 0.3437713110796273 epoch: 16, acc: 0.8401314716595927, 2891/6250
loss: 0.3438718987870119 epoch: 16, acc: 0.8399730805436341, 2940/6250
loss: 0.3437201032944547 epoch: 16, acc: 0.8400973056446666, 2989/6250
loss: 0.34307813216595323 epoch: 16, acc: 0.8404519447910588, 3038/6250
loss: 0.34330029468356094 epoch: 16, acc: 0.8404565960203514, 3087/6250
loss: 0.34359753738695703 epoch: 16, acc: 0.8403559699376852, 3136/6250
loss: 0.3440842622796358 epoch: 16, acc: 0.8402169123004678, 3185/6250
loss: 0.34397762333729537 epoch: 16, acc: 0.8402907237928997, 3234/6250
loss: 0.34384747374752916 epoch: 16, acc: 0.8402897986663088, 3283/6250


loss: 0.33637460110273526 epoch: 17, acc: 0.844754543608488, 2009/6250
loss: 0.33619119630038247 epoch: 17, acc: 0.8447954748580768, 2058/6250
loss: 0.33594555937220294 epoch: 17, acc: 0.8449609178348755, 2107/6250
loss: 0.335737031991138 epoch: 17, acc: 0.8449576406210082, 2156/6250
loss: 0.3359942301391799 epoch: 17, acc: 0.8448854004890994, 2205/6250
loss: 0.3357068554855503 epoch: 17, acc: 0.8449949481594186, 2254/6250
loss: 0.33584858463533845 epoch: 17, acc: 0.8448503239084527, 2303/6250
loss: 0.3360534505708847 epoch: 17, acc: 0.8447961263898877, 2352/6250
loss: 0.33667839814805667 epoch: 17, acc: 0.8444794805216518, 2401/6250
loss: 0.33672672352076843 epoch: 17, acc: 0.8445826992535833, 2450/6250
loss: 0.3368370159983635 epoch: 17, acc: 0.844632158801019, 2499/6250
loss: 0.33708358752746215 epoch: 17, acc: 0.8445761386768309, 2548/6250
loss: 0.33686104924450294 epoch: 17, acc: 0.844680090208715, 2597/6250
loss: 0.33708383560090593 epoch: 17, acc: 0.8445042133217315, 2646/6250
l

loss: 0.33491013381421264 epoch: 18, acc: 0.8451181702025762, 1372/6250
loss: 0.33418115702590023 epoch: 18, acc: 0.8454532871297243, 1421/6250
loss: 0.3333491735935292 epoch: 18, acc: 0.8459360847234446, 1470/6250
loss: 0.33421992839461095 epoch: 18, acc: 0.845758633557246, 1519/6250
loss: 0.3347911174231612 epoch: 18, acc: 0.8456815438591779, 1568/6250
loss: 0.3338806488227712 epoch: 18, acc: 0.8460637961051997, 1617/6250
loss: 0.33499050049650936 epoch: 18, acc: 0.8454744219364029, 1666/6250
loss: 0.33468810372089314 epoch: 18, acc: 0.8456647344935564, 1715/6250
loss: 0.334759922815281 epoch: 18, acc: 0.8456825458804265, 1764/6250
loss: 0.33551859865098566 epoch: 18, acc: 0.8454880935974158, 1813/6250
loss: 0.33606912453633386 epoch: 18, acc: 0.844965324847732, 1862/6250
loss: 0.3367146400921826 epoch: 18, acc: 0.8445009497531029, 1911/6250
loss: 0.336422936796139 epoch: 18, acc: 0.8446250534172741, 1960/6250
loss: 0.33702495278633054 epoch: 18, acc: 0.8444762495716565, 2009/6250
lo

loss: 0.3299789134838173 epoch: 19, acc: 0.8482279874840565, 735/6250
loss: 0.3293006064880426 epoch: 19, acc: 0.8486177572143514, 784/6250
loss: 0.32911317126392176 epoch: 19, acc: 0.848688945845875, 833/6250
loss: 0.3285139731699991 epoch: 19, acc: 0.8488846666303138, 882/6250
loss: 0.32943156442321164 epoch: 19, acc: 0.8488619202039728, 931/6250
loss: 0.32943622562103925 epoch: 19, acc: 0.8486034560995661, 980/6250
loss: 0.3306377911958301 epoch: 19, acc: 0.8482663645327679, 1029/6250
loss: 0.33045768574142587 epoch: 19, acc: 0.8483766428547792, 1078/6250
loss: 0.32996792047687457 epoch: 19, acc: 0.8485476424404541, 1127/6250
loss: 0.3309942159038939 epoch: 19, acc: 0.847972697701373, 1176/6250
loss: 0.3308717232154983 epoch: 19, acc: 0.847892864100423, 1225/6250
loss: 0.3306317900091994 epoch: 19, acc: 0.8479420283738505, 1274/6250
loss: 0.33092625377266427 epoch: 19, acc: 0.847861727625105, 1323/6250
loss: 0.3310939280420694 epoch: 19, acc: 0.8478253343569324, 1372/6250
loss: 0.33

In [17]:
from skimage import io
import matplotlib.pyplot as plt

class test_dataset(Dataset):
    def __init__(self, img_path, transform=None):
        super().__init__()
        self.dataset = os.listdir(img_path)
        self.transform = transform
    
    def __len__(self):
        return len(self.dataset)
    
    def __getitem__(self, idx):
        image = Image.open(os.path.join('./unet_train/test/', self.dataset[idx]))
        if self.transform is not None:
            image = self.transform(image)
        
        return image


palette = {0: (0, 0, 0),        # Undefined (black)
           1: (0, 255, 0),      # Trees (green)
           2: (0, 0, 255),      # Buildings (red)
           3: (255, 0, 0),      # Water (blue)
           4: (0, 255, 255)}    # Roads (yellow)
    
    
def convert_to_color(arr_2d, palette=palette):
    # Numeric labels to RGB-color encoding
    arr_3d = np.zeros((arr_2d.shape[0], arr_2d.shape[1], 3), dtype=np.uint8)
    for c, i in palette.items():
        m = arr_2d == c
        arr_3d[m] = i

    return arr_3d
    
    
tst_dataset =  test_dataset('./unet_train/test/', transform=transform)
tst_dataloader = DataLoader(tst_dataset, batch_size=1, num_workers=0, shuffle=False)

model = torch.load('model_16.pth')
model.eval()
for i, image in enumerate(tst_dataloader):
    image = image.to('cuda').float()
    outputs = model(image)
    outputs = outputs.cpu().detach().squeeze().numpy()
    print(outputs)
    #outputs = outputs > 0.5
    
    #RGB = convert_to_color(outputs)
    #io.imsave('./unet_train/pre_{}.png'.format(i), RGB)
    
#outputs = (outputs>0.5) * 255


[[0.11257628 0.03616188 0.02804123 ... 0.02947    0.03467409 0.10364516]
 [0.04674973 0.00798338 0.00509478 ... 0.00591495 0.0074211  0.04167047]
 [0.04395681 0.00666742 0.00372995 ... 0.00461472 0.00634709 0.04029362]
 ...
 [0.0498231  0.00814993 0.00507668 ... 0.00788751 0.00966827 0.05611226]
 [0.05469604 0.01023869 0.00728711 ... 0.0102478  0.01153989 0.05756445]
 [0.12737349 0.0475118  0.04091154 ... 0.05024809 0.05177438 0.13881734]]
[[0.11158665 0.03492364 0.02709022 ... 0.0282689  0.03314438 0.10267923]
 [0.04709937 0.00775517 0.00496569 ... 0.00591159 0.00735112 0.04264004]
 [0.04440732 0.0064431  0.00360508 ... 0.00469352 0.00635594 0.04223432]
 ...
 [0.04103954 0.00585152 0.00354172 ... 0.01046022 0.01188981 0.06026326]
 [0.04732029 0.00795202 0.0055149  ... 0.01298653 0.01361096 0.06062268]
 [0.11742714 0.04095023 0.03464228 ... 0.05705291 0.05665679 0.14229617]]
[[0.11845004 0.03752591 0.02986662 ... 0.03222935 0.0371578  0.10828361]
 [0.0493969  0.00791963 0.00519598 ... 