In [1]:
import torch
from torch import nn
from torchvision import transforms
from torch.utils.data import DataLoader, Dataset
import os
import cv2
import pdb
from onehot import onehot
from torch.nn import functional as F
from datetime import datetime

In [2]:
# Image preprocessing shared by both datasets:
#   1. ToTensor: HxWxC uint8 array -> CxHxW float tensor scaled to [0, 1].
#   2. Normalize: per-channel (x - mean) / std using the standard ImageNet
#      statistics published with torchvision's pretrained models.
# NOTE(review): the datasets load images with cv2.imread, which yields BGR
# channel order, while these statistics are listed in RGB order — confirm
# whether a BGR->RGB conversion is intended.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    ),
])

In [3]:
class TrainDatasets(Dataset):
    """Training image / label pairs read from two parallel directories.

    Every file name in ``image_dir`` must also exist in ``label_dir``. Items are
    dicts ``{'A': image, 'B': label}``. ``image`` is the resized colour image,
    passed through ``transform`` when one is given. ``label`` is a float tensor
    built with ``onehot`` and with its last axis moved to the front.

    Args:
        transform: optional callable applied to the resized image array as
            returned by ``cv2.imread`` (BGR channel order, uint8).
        image_dir: directory holding the input images.
        label_dir: directory holding the label images (read as grayscale).
        size: ``(width, height)`` target passed to ``cv2.resize``.
        num_classes: class count passed to ``onehot``.
    """

    def __init__(self, transform=None, image_dir='train_image',
                 label_dir='train_label', size=(800, 600), num_classes=20):
        self.transform = transform
        self.image_dir = image_dir
        self.label_dir = label_dir
        self.size = size
        self.num_classes = num_classes
        # List the directory once (it used to be re-listed on every __len__ and
        # __getitem__ call) and sort it so the index -> file mapping is stable.
        self.image_names = sorted(os.listdir(image_dir))

    def __len__(self):
        return len(self.image_names)

    @staticmethod
    def _imread(path, flags=None):
        """cv2.imread that raises FileNotFoundError instead of returning None."""
        img = cv2.imread(path) if flags is None else cv2.imread(path, flags)
        if img is None:
            raise FileNotFoundError(f'could not read image: {path}')
        return img

    def __getitem__(self, idx):
        name = self.image_names[idx]
        image = cv2.resize(self._imread(os.path.join(self.image_dir, name)), self.size)

        label = self._imread(os.path.join(self.label_dir, name), cv2.IMREAD_GRAYSCALE)
        # Nearest-neighbour keeps mask values exact. The default bilinear filter
        # blends values along object edges, and the truncation below would then
        # silently turn those edge pixels into background.
        label = cv2.resize(label, self.size, interpolation=cv2.INTER_NEAREST)
        # Divide the uint8 values by 255 and truncate: 255 -> 1, everything else -> 0.
        # NOTE(review): assumes masks are stored as 0/255. Labels that hold class
        # indices would all collapse to 0 here, which is at odds with
        # num_classes=20; confirm this against the label files.
        label = (label / 255).astype('uint8')
        label = onehot(label, self.num_classes)
        # Move the last axis to the front: (H, W, C) -> (C, H, W), assuming
        # onehot appends the class axis last — TODO confirm.
        label = torch.FloatTensor(label.swapaxes(0, 2).swapaxes(1, 2))

        if self.transform:
            image = self.transform(image)
        return {'A': image, 'B': label}

In [4]:
train_sets = TrainDatasets(transform)
train_data = DataLoader(dataset=train_sets, batch_size=4, shuffle=True)

if __name__ == '__main__':
    # Smoke test: pull a single shuffled batch to confirm that reading,
    # preprocessing and collation all work before training starts.
    batch = next(iter(train_data))

In [5]:
class TestDatasets(Dataset):
    """Evaluation image / label pairs read from two parallel directories.

    Every file name in ``image_dir`` must also exist in ``label_dir``. Items are
    dicts ``{'A': image, 'B': label}``. ``image`` is the resized colour image,
    passed through ``transform`` when one is given. ``label`` is a float tensor
    built with ``onehot`` and with its last axis moved to the front.

    Defaults were changed from 400x300 / 2 classes to 800x600 / 20 classes:
      * ``FCN.classifier`` emits 20 channels, so 2-channel targets cannot be
        scored against its output.
      * ``FCN``'s skip connection adds tensors whose sizes only agree when
        H and W are multiples of 8, and H=300 is not.
    Pass ``size`` / ``num_classes`` explicitly to restore the old values.

    Args:
        transform: optional callable applied to the resized image array as
            returned by ``cv2.imread`` (BGR channel order, uint8).
        image_dir: directory holding the input images.
        label_dir: directory holding the label images (read as grayscale).
        size: ``(width, height)`` target passed to ``cv2.resize``.
        num_classes: class count passed to ``onehot``.
    """

    def __init__(self, transform=None, image_dir='test_image',
                 label_dir='test_label', size=(800, 600), num_classes=20):
        self.transform = transform
        self.image_dir = image_dir
        self.label_dir = label_dir
        self.size = size
        self.num_classes = num_classes
        # List the directory once (it used to be re-listed on every __len__ and
        # __getitem__ call) and sort it so the index -> file mapping is stable.
        self.image_names = sorted(os.listdir(image_dir))

    def __len__(self):
        return len(self.image_names)

    @staticmethod
    def _imread(path, flags=None):
        """cv2.imread that raises FileNotFoundError instead of returning None."""
        img = cv2.imread(path) if flags is None else cv2.imread(path, flags)
        if img is None:
            raise FileNotFoundError(f'could not read image: {path}')
        return img

    def __getitem__(self, idx):
        name = self.image_names[idx]
        image = cv2.resize(self._imread(os.path.join(self.image_dir, name)), self.size)

        label = self._imread(os.path.join(self.label_dir, name), cv2.IMREAD_GRAYSCALE)
        # Nearest-neighbour keeps mask values exact. The default bilinear filter
        # blends values along object edges, and the truncation below would then
        # silently turn those edge pixels into background.
        label = cv2.resize(label, self.size, interpolation=cv2.INTER_NEAREST)
        # Divide the uint8 values by 255 and truncate: 255 -> 1, everything else -> 0.
        # NOTE(review): assumes masks are stored as 0/255; confirm against the
        # label files.
        label = (label / 255).astype('uint8')
        label = onehot(label, self.num_classes)
        # Move the last axis to the front: (H, W, C) -> (C, H, W), assuming
        # onehot appends the class axis last — TODO confirm.
        label = torch.FloatTensor(label.swapaxes(0, 2).swapaxes(1, 2))

        if self.transform:
            image = self.transform(image)
        return {'A': image, 'B': label}

In [6]:
test_sets = TestDatasets(transform)
# Evaluation data does not need shuffling. A fixed order keeps evaluation
# (and the smoke-test batch below) reproducible from run to run.
test_data = DataLoader(test_sets, batch_size=4, shuffle=False)
if __name__ == '__main__':
    # Smoke test: load one batch to confirm that reading and collation work.
    for batch in test_data:
        break

In [7]:
class FCN(nn.Module):
    """Small fully convolutional encoder-decoder producing per-pixel class scores.

    Encoder: three stages of 3x3 conv -> ReLU -> BatchNorm -> 2x2 max-pool. Each
    stage halves H and W, for an overall factor of 8.

    Decoder: three stride-2 transposed convs (kernel 4, padding 1), each of which
    doubles H and W.

    The output of the second decoder stage is summed with the output of the
    first encoder stage (a single skip connection). A final 1x1 conv maps 32
    channels to ``num_classes`` raw scores; no activation is applied here.

    Inputs whose H/W are not multiples of 8 are supported. Decoder features are
    bilinearly resized to match the skip tensor, and the logits to match the
    input, only where their sizes disagree. For multiples of 8 no resizing
    happens, so the output is identical to the plain architecture.

    Args:
        num_classes: number of output channels. Defaults to 20, the value that
            was previously hard-coded.
    """

    def __init__(self, num_classes=20):
        super().__init__()
        # Attribute names are unchanged so that existing state_dicts still load.
        self.conv1 = nn.Conv2d(3, 64, 3, 1, 1)
        self.batch_norm1 = nn.BatchNorm2d(64)
        self.max_pool1 = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(64, 128, 3, 1, 1)
        self.batch_norm2 = nn.BatchNorm2d(128)
        self.max_pool2 = nn.MaxPool2d(2, 2)
        self.conv3 = nn.Conv2d(128, 256, 3, 1, 1)
        self.batch_norm3 = nn.BatchNorm2d(256)
        self.max_pool3 = nn.MaxPool2d(2, 2)
        self.deconv1 = nn.ConvTranspose2d(256, 128, 4, 2, 1)
        self.debatch1 = nn.BatchNorm2d(128)
        self.deconv2 = nn.ConvTranspose2d(128, 64, 4, 2, 1)
        self.debatch2 = nn.BatchNorm2d(64)
        self.deconv3 = nn.ConvTranspose2d(64, 32, 4, 2, 1)
        self.debatch3 = nn.BatchNorm2d(32)
        self.classifier = nn.Conv2d(32, num_classes, 1)

    @staticmethod
    def _resize_to(x, size):
        """Bilinearly resize ``x`` to spatial ``size``; return it unchanged if it already matches."""
        size = tuple(size)
        if tuple(x.shape[-2:]) == size:
            return x
        return F.interpolate(x, size=size, mode='bilinear', align_corners=False)

    def forward(self, x):
        """Map an (N, 3, H, W) batch to (N, num_classes, H, W) scores."""
        input_size = x.shape[-2:]
        # NOTE: ReLU is applied *before* BatchNorm throughout (conv -> BN -> ReLU
        # is the more common order). It is kept so trained weights behave identically.
        x = self.max_pool1(self.batch_norm1(F.relu(self.conv1(x))))
        skip = x
        x = self.max_pool2(self.batch_norm2(F.relu(self.conv2(x))))
        x = self.max_pool3(self.batch_norm3(F.relu(self.conv3(x))))
        x = self.debatch1(F.relu(self.deconv1(x)))
        x = self.debatch2(F.relu(self.deconv2(x)))
        # The pools floor odd sizes, so when H or W is not a multiple of 8 the
        # decoder output is smaller than the skip tensor; align before adding.
        x = self._resize_to(x, skip.shape[-2:]) + skip
        x = self.debatch3(F.relu(self.deconv3(x)))
        x = self.classifier(x)
        return self._resize_to(x, input_size)

In [8]:
# Use the GPU when present, but fall back to CPU instead of crashing on .cuda().
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
fcn_net = FCN().to(device)
optimizer = torch.optim.Adam(fcn_net.parameters(), lr=0.01)
# BCELoss expects probabilities in [0, 1], so the network output must go
# through a sigmoid before this criterion. Constructed without `weight`, it
# holds no tensors, so the former `.cuda()` call was a no-op.
criterion = nn.BCELoss()

In [9]:
def train_one_epoch(model, loader, optimizer, criterion):
    """Run one optimisation pass over ``loader`` and return the mean batch loss.

    Args:
        model: network whose output is passed through ``torch.sigmoid`` before
            ``criterion``, so ``criterion`` must accept probabilities
            (e.g. ``nn.BCELoss``).
        loader: iterable of dicts with ``'A'`` (input batch) and ``'B'``
            (target batch); ``len(loader)`` must be the number of batches.
        optimizer: optimizer over ``model``'s parameters.
        criterion: loss callable taking ``(probabilities, targets)``.

    Returns:
        Sum of per-batch losses divided by ``len(loader)``.
    """
    # Explicitly enable training mode so BatchNorm uses batch statistics even if
    # an earlier cell left the model in eval mode.
    model.train()
    # Follow whatever device the model lives on rather than hard-coding CUDA.
    device = next(model.parameters()).device
    total_loss = 0.0
    for batch_item in loader:
        images = batch_item['A'].to(device)
        targets = batch_item['B'].to(device)

        optimizer.zero_grad()
        probs = torch.sigmoid(model(images))
        loss = criterion(probs, targets)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    return total_loss / len(loader)


NUM_EPOCHS = 3
saving_index = 0
for epo in range(NUM_EPOCHS):
    saving_index += 1
    mean_loss = train_one_epoch(fcn_net, train_data, optimizer, criterion)
    print('epoch loss = %f' % mean_loss)



epoch loss = 0.046970
epoch loss = 0.020381
epoch loss = 0.018153
