In [2]:
import os
import json
import cv2
import torch
import numpy as np
import matplotlib.pyplot as plt
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader
from PIL import Image
from torch.utils.tensorboard import SummaryWriter

In [4]:
import torchvision

In [5]:
import torch
import torch.nn as nn
#model = torchvision.models.segmentation.fcn_resnet101(pretrained=True)

In [6]:
class Dataset(torch.utils.data.Dataset):
    """Image / polygon-annotation dataset for one split of the project data.

    Every item is a dict ``{'data': image, 'label': label}`` of float32 tensors
    produced by ``ToTensor`` from arrays divided by 255. The label is a copy of
    the source photo with every annotated polygon painted in the colour
    (0, 50, 150).
    NOTE(review): the label is a painted copy of the photo rather than a binary
    mask - confirm this is the intended training target.

    Images and annotation files are paired by position after both directory
    listings are sorted, so matching files must sort into the same order
    (e.g. share a file stem).

    Args:
        mode: 'train' or 'val' selects the ``train`` / ``validation``
            sub-folder; any other value (including None) selects ``test``.
        transform: optional callable applied separately to the image tensor
            and to the label tensor.
        img_dir: root folder containing the per-split image folders. Defaults
            to the project path that used to be hard-coded.
        json_dir: root folder containing the per-split annotation folders.
            Defaults to the project path that used to be hard-coded.
    """

    # Raw strings: the backslashes of a Windows path must never be parsed as
    # escape sequences (the non-raw form only works by accident and warns).
    DEFAULT_IMG_DIR = r"D:\Desktop\project\AugusTOOTH\project\data\img_data"
    DEFAULT_JSON_DIR = r"D:\Desktop\project\AugusTOOTH\project\data\json_data"

    def __init__(self, mode=None, transform=None, img_dir=None, json_dir=None):
        super().__init__()

        self.img_dir = self.DEFAULT_IMG_DIR if img_dir is None else img_dir
        self.json_dir = self.DEFAULT_JSON_DIR if json_dir is None else json_dir

        self.transform = transform
        self.to_tensor = ToTensor()

        if mode == 'train':
            split = 'train'
        elif mode == 'val':
            split = 'validation'
        else:
            split = 'test'
        self.path = os.path.join(self.img_dir, split)
        self.json_path = os.path.join(self.json_dir, split)

        # os.listdir returns entries in arbitrary order; sort both listings so
        # that index i refers to the same sample in each of them.
        self.filenames = sorted(os.listdir(self.path))
        self.json_filenames = sorted(os.listdir(self.json_path))

    def __getitem__(self, idx):
        """Return ``{'data': image, 'label': label}`` for sample ``idx``."""
        rgb = self._load_rgb(idx)
        # Paint on a copy so the untouched pixels remain available as the input.
        label_img = self._paint_annotations(rgb.copy(), idx) / 255.0
        img = rgb.astype(np.float32) / 255.0

        sample = self.to_tensor({'data': img, 'label': label_img})

        if self.transform is not None:
            # The transform runs twice, once per tensor: a randomised transform
            # would therefore not stay aligned between image and label.
            sample['data'] = self.transform(sample['data'])
            sample['label'] = self.transform(sample['label'])

        return sample

    def __len__(self):
        """Number of image files found in the split folder."""
        return len(self.filenames)

    def label_image(self, idx):
        """Return the float32 (Y, X, CH) label array for sample ``idx``."""
        return self._paint_annotations(self._load_rgb(idx), idx)

    def _load_rgb(self, idx):
        """Read image ``idx`` as a writable uint8 (Y, X, 3) array."""
        img_path = os.path.join(self.path, self.filenames[idx])
        # The context manager closes the file handle; np.array makes a writable
        # copy (arrays viewed straight from PIL can be read-only), and the RGB
        # conversion guarantees the three channels the rest of the code expects.
        with Image.open(img_path) as pil_img:
            return np.array(pil_img.convert('RGB'))

    def _paint_annotations(self, canvas, idx):
        """Draw every annotated polygon of sample ``idx`` onto ``canvas`` in place."""
        with open(os.path.join(self.json_path, self.json_filenames[idx]), 'r') as f:
            json_data = json.load(f)

        for annotation in json_data['annotations']:
            # OpenCV drawing functions require 32-bit integer point arrays.
            polygon = np.asarray(annotation['points'])
            if polygon.size == 0:
                continue
            polygon = np.rint(polygon).astype(np.int32)
            cv2.polylines(canvas, [polygon], True, (0, 0, 0))
            cv2.fillPoly(canvas, [polygon], (0, 50, 150))

        return canvas.astype(np.float32)
class ToTensor(object):
    """Convert a ``{'data', 'label'}`` pair of numpy images into torch tensors."""

    def __call__(self, data):
        """Return a new dict whose arrays are float32 tensors in (CH, Y, X) order.

        The numpy images arrive as (Y, X, CH); the channel axis is moved to the
        front because that is the layout the tensors use.
        """
        def channels_first(array):
            return torch.from_numpy(array.transpose((2, 0, 1)).astype(np.float32))

        label_tensor = channels_first(data['label'])
        image_tensor = channels_first(data['data'])
        return {'data': image_tensor, 'label': label_tensor}

In [7]:
# Optimisation settings read by the data loaders, the optimizer and the training loop:
# learning rate, mini-batch size and number of epochs.
lr, batch_size, num_epoch = 1e-3, 4, 5

In [12]:
# Stage selector: the cells below compare this flag against 'train', 'val' and 'test'.
mode = "train"

In [14]:
# Switch the stage selector to validation. Because this cell overwrites the same
# global flag, the order in which the cells are executed decides which branches run.
mode = 'val'

In [13]:
# Build the training dataset and loader.
# The batch-count cell and the training loop need the training AND the validation
# objects at the same time, so this cell runs for both stages instead of only when
# mode == 'train' (with a top-to-bottom run the flag is already 'val' here).
if mode in ('train', 'val'):
    transforms_train = transforms.Compose([
        transforms.ToPILImage(),
        transforms.Resize((512, 512)),
        transforms.ToTensor(),
    ])
    # The split is named explicitly because `mode` may currently be 'val'.
    dataset_train = Dataset(mode='train', transform=transforms_train)
    load_train = DataLoader(dataset_train, batch_size=batch_size, shuffle=True, num_workers=0)

In [15]:
# Build the validation dataset and loader.
# The training loop (which runs with mode == 'train') evaluates on `load_val` after
# every epoch, so the loader must exist for both stages, not only when mode == 'val'.
if mode in ('train', 'val'):
    transforms_val = transforms.Compose([
        transforms.ToPILImage(),
        transforms.Resize((512, 512)),
        transforms.ToTensor(),
    ])
    # The split is named explicitly because `mode` may currently be 'train'.
    dataset_val = Dataset(mode='val', transform=transforms_val)
    # Evaluation does not benefit from shuffling; a fixed order keeps runs comparable.
    load_val = DataLoader(dataset_val, batch_size=batch_size, shuffle=False, num_workers=0)

In [16]:
# Sample counts per split and the number of mini-batches per epoch.
# np.ceil counts a partial final batch and yields a float.
num_data_train, num_data_val = len(dataset_train), len(dataset_val)

num_batch_train, num_batch_val = (
    np.ceil(sample_count / batch_size)
    for sample_count in (num_data_train, num_data_val)
)

In [None]:
# Build the test dataset and loader.
if mode == 'test':
    # Same preprocessing as the train/val cells. Dataset hands the transform a
    # tensor, so the pipeline has to start with ToPILImage; torchvision's ToTensor
    # cannot be applied to a tensor directly. Normalize is left out because the
    # train/val inputs are not normalised and the transform is applied to the label too.
    transforms_test = transforms.Compose([
        transforms.ToPILImage(),
        transforms.Resize((512, 512)),
        transforms.ToTensor(),
    ])
    # Dataset's keyword is `transform` (singular).
    dataset_test = Dataset(mode='test', transform=transforms_test)
    # Keep a fixed order so predictions can be matched back to their files.
    load_test = DataLoader(dataset_test, batch_size=batch_size, shuffle=False, num_workers=0)

Shape of one raw image array (Y, X, CH): 2704, 4064, 3
``` Python
i=0
for batch, data in enumerate(load_train, 1):
    print("dd")
    label = data
    i+=1
    if i == 3:
        break
```

In [13]:
# Smoke test: pull at most three batches from the training loader and keep the
# last one in `label` so the next cell can inspect it.
i = 0
for batch, data in enumerate(load_train, 1):
    print("dd")
    label = data
    i = batch  # enumerate starts at 1, so this equals the number of batches seen
    if i == 3:
        break

dd
dd
dd


In [14]:
# Shape of the label tensor in the batch captured by the smoke-test cell.
label['label'].shape

torch.Size([4, 3, 512, 512])

In [23]:
# Instantiate the segmentation network.
# NOTE: the UNet class is defined in a cell further down this notebook, so that
# cell has to be executed before this one.
net = UNet()

In [24]:
# Loss: binary cross-entropy computed directly on the network's raw logits.
loss_func = nn.BCEWithLogitsLoss()

# Optimizer: Adam over every parameter of the network, using the configured learning rate.
optim = torch.optim.Adam(params=net.parameters(), lr=lr)

In [17]:
# One TensorBoard writer per stage, logging under tensorboard/train and tensorboard/val.
writer_train, writer_val = (
    SummaryWriter(log_dir=os.path.join('tensorboard', stage))
    for stage in ('train', 'val')
)

In [18]:
# Starting epoch index and the 'on'/'off'-style switch for continuing a previous run.
st_epoch, train_continue = 0, 'off'

In [20]:
# Set the stage selector back to training so the training-loop cell below runs.
mode = 'train'

In [None]:
# Train for `num_epoch` epochs; after each epoch run one pass over the validation
# loader with gradients disabled. The running mean loss is printed per batch.
# NOTE: `input` shadows the Python builtin at notebook scope; the name is kept so
# that other cells relying on it keep working.
if mode == 'train':
    for epoch in range(0, num_epoch):
        # ---- training pass ----
        net.train()
        loss_arr = []

        for batch, data in enumerate(load_train, 1):
            label = data['label']
            input = data['data']

            output = net(input)

            optim.zero_grad()

            loss = loss_func(output, label)
            loss.backward()

            optim.step()

            loss_arr.append(loss.item())

            # `epoch + 1` so the epoch counter is 1-based like the batch counter.
            print("TRAIN: EPOCH %04d / %04d | BATCH %04d / %04d | LOSS %.4f" %
                  (epoch + 1, num_epoch, batch, num_batch_train, np.mean(loss_arr)))

        # ---- validation pass ----
        with torch.no_grad():
            net.eval()
            loss_arr = []

            for batch, data in enumerate(load_val, 1):
                # forward pass
                label = data['label']
                input = data['data']

                output = net(input)

                # Same criterion as training (the former `fn_loss` name was never defined).
                loss = loss_func(output, label)

                loss_arr.append(loss.item())

                print("VALID: EPOCH %04d / %04d | BATCH %04d / %04d | LOSS %.4f" %
                      (epoch + 1, num_epoch, batch, num_batch_val, np.mean(loss_arr)))

In [22]:
class UNet(nn.Module):
    """U-Net style encoder/decoder mapping a 3-channel image to a 3-channel logit map.

    Four 2x max-pooling stages shrink the input and four stride-2 transposed
    convolutions grow it back. At every decoder level the upsampled features are
    concatenated, along the channel axis, with the encoder output of the same
    depth. Height and width should be divisible by 16 so that the upsampled maps
    match the skip tensors they are concatenated with.
    """

    def __init__(self):
        super().__init__()

        def CBR2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=True):
            """Return a Conv2d -> BatchNorm2d -> ReLU stage."""
            return nn.Sequential(
                nn.Conv2d(in_channels=in_channels, out_channels=out_channels,
                          kernel_size=kernel_size, stride=stride, padding=padding, bias=bias),
                nn.BatchNorm2d(num_features=out_channels),
                nn.ReLU(),
            )

        def upsample(channels):
            """Return a transposed convolution that doubles height and width."""
            return nn.ConvTranspose2d(in_channels=channels, out_channels=channels,
                                      kernel_size=2, stride=2, padding=0, bias=True)

        # ---- contracting path (encoder) ----
        self.enc1_1 = CBR2d(in_channels=3, out_channels=64)
        self.enc1_2 = CBR2d(in_channels=64, out_channels=64)
        self.pool1 = nn.MaxPool2d(kernel_size=2)

        self.enc2_1 = CBR2d(in_channels=64, out_channels=128)
        self.enc2_2 = CBR2d(in_channels=128, out_channels=128)
        self.pool2 = nn.MaxPool2d(kernel_size=2)

        self.enc3_1 = CBR2d(in_channels=128, out_channels=256)
        self.enc3_2 = CBR2d(in_channels=256, out_channels=256)
        self.pool3 = nn.MaxPool2d(kernel_size=2)

        self.enc4_1 = CBR2d(in_channels=256, out_channels=512)
        self.enc4_2 = CBR2d(in_channels=512, out_channels=512)
        self.pool4 = nn.MaxPool2d(kernel_size=2)

        self.enc5_1 = CBR2d(in_channels=512, out_channels=1024)

        # ---- expansive path (decoder) ----
        # Each *_2 stage takes twice the channels: upsampled features + skip connection.
        self.dec5_1 = CBR2d(in_channels=1024, out_channels=512)

        self.unpool4 = upsample(512)
        self.dec4_2 = CBR2d(in_channels=2 * 512, out_channels=512)
        self.dec4_1 = CBR2d(in_channels=512, out_channels=256)

        self.unpool3 = upsample(256)
        self.dec3_2 = CBR2d(in_channels=2 * 256, out_channels=256)
        self.dec3_1 = CBR2d(in_channels=256, out_channels=128)

        self.unpool2 = upsample(128)
        self.dec2_2 = CBR2d(in_channels=2 * 128, out_channels=128)
        self.dec2_1 = CBR2d(in_channels=128, out_channels=64)

        self.unpool1 = upsample(64)
        self.dec1_2 = CBR2d(in_channels=2 * 64, out_channels=64)
        self.dec1_1 = CBR2d(in_channels=64, out_channels=64)

        # Author's note: a 2-channel output with nn.CrossEntropyLoss is equivalent to
        # a 1-channel output with nn.BCELoss (binary cross-entropy). This head emits
        # 3 channels through a 1x1 convolution.
        self.fc = nn.Conv2d(in_channels=64, out_channels=3, kernel_size=1, stride=1, padding=0, bias=True)

    def forward(self, x):
        """Run the encoder, bottleneck and decoder; return the raw output of the 1x1 head."""
        # Encoder: remember the output of every level for the skip connections.
        skip1 = self.enc1_2(self.enc1_1(x))
        skip2 = self.enc2_2(self.enc2_1(self.pool1(skip1)))
        skip3 = self.enc3_2(self.enc3_1(self.pool2(skip2)))
        skip4 = self.enc4_2(self.enc4_1(self.pool3(skip3)))

        # Bottleneck.
        features = self.dec5_1(self.enc5_1(self.pool4(skip4)))

        # Decoder: upsample, concatenate with the matching skip, then two conv stages.
        # Tensor dims are [0: batch, 1: channel, 2: height, 3: width], hence dim=1.
        features = torch.cat((self.unpool4(features), skip4), dim=1)
        features = self.dec4_1(self.dec4_2(features))

        features = torch.cat((self.unpool3(features), skip3), dim=1)
        features = self.dec3_1(self.dec3_2(features))

        features = torch.cat((self.unpool2(features), skip2), dim=1)
        features = self.dec2_1(self.dec2_2(features))

        features = torch.cat((self.unpool1(features), skip1), dim=1)
        features = self.dec1_1(self.dec1_2(features))

        return self.fc(features)