## U-Net

In [2]:
import numpy as np
import torch
from torch import optim
from torchvision.transforms import transforms
from torch import nn
import torch.utils.data as data
from torch.utils.data import DataLoader
import PIL.Image as Image
import os
from matplotlib import pyplot as plt
import os
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

In [3]:
# 读取数据的路径
def make_dataset(root):
    imgs = []
    # 计算共有多少张原始图片
    n = len(os.listdir(root))//2
    for i in range(n):
        # 找到00i.png的路径
        img = os.path.join(root, '%03d.png'%i)
        # 找到00i_mask.png的路径
        mask = os.path.join(root, '%03d_mask.png'%i)
        # 添加至列表
        imgs.append((img, mask))
    return imgs

In [4]:
class LiverDataset(data.Dataset):
    
    def __init__(self, root, transform=None, target_transform=None):
        imgs = make_dataset(root)
        self.imgs = imgs
        self.transform = transform
        self.target_transform = target_transform
    
    def __getitem__(self, index):
        x_path, y_path = self.imgs[index]
        img_x = Image.open(x_path)
        img_y = Image.open(y_path)
        if self.transform is not None:
            img_x = self.transform(img_x)
        if self.target_transform is not None:
            img_y = self.target_transform(img_y)
        return img_x, img_y
    
    def __len__(self):
        return len(self.imgs)

In [5]:
x_transforms = transforms.Compose([
    transforms.ToTensor(), 
    transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
])
# mask只需转为Tensor
y_transforms = transforms.ToTensor()

In [6]:
batch_size = 4
liver_dataset = LiverDataset('../../dataset/u_net_data/liver/train', transform=x_transforms, target_transform=y_transforms)
dataloaders = DataLoader(liver_dataset, batch_size=batch_size, shuffle=True)

In [7]:
for x, y in dataloaders:
    print(x.shape)
    print(y.shape)
    break

torch.Size([4, 3, 512, 512])
torch.Size([4, 1, 512, 512])


In [8]:
# U_Net模型中的双卷积网络结构
class DoubleConv(nn.Module):
    def __init__(self, in_ch, out_ch):
        super(DoubleConv, self).__init__()
        self.conv = nn.Sequential(
            # 此处包含padding，为了使输出图像与输入图像大小相同
            nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1), 
            nn.BatchNorm2d(out_ch), 
            nn.ReLU(inplace=True), 
                        
            nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1), 
            nn.BatchNorm2d(out_ch), 
            nn.ReLU(inplace=True)
        )
        
    def forward(self, input):
        return self.conv(input)

In [9]:
class Unet(nn.Module):
    def __init__(self, in_ch, out_ch):
        super(Unet, self).__init__()
        
        # 特征图大小不变
        self.conv1 = DoubleConv(in_ch, 64)
        
        # 特征图大小长宽减半
        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2, padding=0)
        
        self.conv2 = DoubleConv(64, 128)
        self.pool2 = nn.MaxPool2d(2)
        
        self.conv3 = DoubleConv(128, 256)
        self.pool3 = nn.MaxPool2d(2)
        
        self.conv4 = DoubleConv(256, 512)
        self.pool4 = nn.MaxPool2d(2)
        
        self.conv5 = DoubleConv(512, 1024)
        
        
        # 长宽翻倍，通道数减半
        self.up6 = nn.ConvTranspose2d(1024, 512, 2, stride=2)
        
        
        self.conv6 = DoubleConv(1024, 512)
        self.up7 = nn.ConvTranspose2d(512, 256, 2, stride=2)
        self.conv7 = DoubleConv(512, 256)
        self.up8 = nn.ConvTranspose2d(256, 128, 2, stride=2)
        self.conv8 = DoubleConv(256, 128)
        self.up9 = nn.ConvTranspose2d(128, 64, 2, stride=2)
        self.conv9 = DoubleConv(128, 64)
        self.conv10 = nn.Conv2d(64, out_ch, 1)
        
    def forward(self, x):
        c1 = self.conv1(x)
        p1 = self.pool1(c1)
        c2 = self.conv2(p1)
        p2 = self.pool2(c2)
        c3 = self.conv3(p2)
        p3 = self.pool3(c3)
        c4 = self.conv4(p3)
        p4 = self.pool4(c4)
        c5 = self.conv5(p4)
        print(c5.shape)
        
        
        up_6 = self.up6(c5)
        # 通道维拼接 [N, C, H, W]
        merge6 = torch.cat([up_6, c4], dim=1)
        c6 = self.conv6(merge6)
        up_7 = self.up7(c6)
        merge7 = torch.cat([up_7, c3], dim=1)
        c7 = self.conv7(merge7)
        up_8 = self.up8(c7)
        merge8 = torch.cat([up_8, c2], dim=1)
        c8 = self.conv8(merge8)
        up_9 = self.up9(c8)
        merge9 = torch.cat([up_9, c1], dim=1)
        c9 = self.conv9(merge9)
        c10 = self.conv10(c9)
        out = torch.sigmoid(c10)
        return out

In [10]:
# 输入图像有3个通道，标签图像有1个通道
net = Unet(3, 1).to(device)
# 类似逻辑回归，用一个输出来实现2分类
loss = torch.nn.BCELoss()
optimizer = optim.Adam(net.parameters())

In [12]:
X = torch.randn(4, 3, 512, 512).to(device)

In [13]:
net(X)

torch.Size([4, 1024, 32, 32])


tensor([[[[0.5538, 0.5421, 0.5571,  ..., 0.4913, 0.5537, 0.5039],
          [0.5598, 0.4745, 0.5655,  ..., 0.5953, 0.5919, 0.4724],
          [0.5025, 0.5493, 0.5103,  ..., 0.5193, 0.5087, 0.6323],
          ...,
          [0.5886, 0.4346, 0.4862,  ..., 0.4667, 0.5076, 0.4837],
          [0.4991, 0.3608, 0.5656,  ..., 0.3720, 0.4094, 0.4996],
          [0.4766, 0.5162, 0.5162,  ..., 0.4920, 0.5678, 0.5347]]],


        [[[0.5352, 0.5092, 0.5718,  ..., 0.4873, 0.5490, 0.4925],
          [0.5411, 0.3935, 0.5044,  ..., 0.4636, 0.4589, 0.4675],
          [0.4430, 0.4684, 0.5073,  ..., 0.5023, 0.5597, 0.5354],
          ...,
          [0.5352, 0.4410, 0.6902,  ..., 0.4788, 0.5981, 0.5758],
          [0.4649, 0.4438, 0.5642,  ..., 0.4355, 0.5632, 0.6019],
          [0.5060, 0.5188, 0.4841,  ..., 0.5058, 0.5956, 0.4391]]],


        [[[0.5218, 0.4170, 0.5241,  ..., 0.4606, 0.6171, 0.5564],
          [0.5342, 0.4510, 0.5429,  ..., 0.4428, 0.4940, 0.5485],
          [0.5396, 0.3643, 0.5061,  ..

In [14]:
def train_model(model, loss, optimizer, dataloaders, num_epochs=20):
    for epoch in range(num_epochs):
        print('Epoch {}/{}'.format(epoch, num_epochs-1))
        print('-'*10)
        dt_size = len(dataloaders.dataset)
        epoch_loss = 0
        step = 0
        for x, y in dataloaders:
            step += 1
            inputs = x.to(device)
            labels = y.to(device)
            optimizer.zero_grad()
            outputs = model(inputs)
            l = loss(outputs, labels)
            l.backward()
            optimizer.step()
            epoch_loss += l.item()
            if step % 200 == 0:
                print('%d/%d, train_loss:%0.3f' % (step, (dt_size-1)//dataloaders.batch_size+1, l.item()))
        print('epoch %d loss:%0.3f' % (epoch, epoch_loss))
    return model

In [None]:
model = train_model(net, loss, optimizer, dataloaders)

Epoch 0/19
----------
torch.Size([4, 1024, 32, 32])
torch.Size([4, 1024, 32, 32])
torch.Size([4, 1024, 32, 32])
torch.Size([4, 1024, 32, 32])
torch.Size([4, 1024, 32, 32])
torch.Size([4, 1024, 32, 32])
torch.Size([4, 1024, 32, 32])
torch.Size([4, 1024, 32, 32])
torch.Size([4, 1024, 32, 32])
torch.Size([4, 1024, 32, 32])
torch.Size([4, 1024, 32, 32])
torch.Size([4, 1024, 32, 32])
torch.Size([4, 1024, 32, 32])
torch.Size([4, 1024, 32, 32])
torch.Size([4, 1024, 32, 32])
torch.Size([4, 1024, 32, 32])
torch.Size([4, 1024, 32, 32])
torch.Size([4, 1024, 32, 32])
torch.Size([4, 1024, 32, 32])
torch.Size([4, 1024, 32, 32])
torch.Size([4, 1024, 32, 32])
torch.Size([4, 1024, 32, 32])
torch.Size([4, 1024, 32, 32])
torch.Size([4, 1024, 32, 32])
torch.Size([4, 1024, 32, 32])
torch.Size([4, 1024, 32, 32])
torch.Size([4, 1024, 32, 32])
torch.Size([4, 1024, 32, 32])
torch.Size([4, 1024, 32, 32])
torch.Size([4, 1024, 32, 32])
torch.Size([4, 1024, 32, 32])
torch.Size([4, 1024, 32, 32])
torch.Size([4, 102

In [None]:
liver_val = LiverDataset('../../dataset/u_net_data/liver/val', transform=x_transforms, target_transform=y_transforms)
liver_val = DataLoader(liver_val, batch_size=1)
model.eval()
with torch.no_grad():
    for i, data in enumerate(dataloaders):
        # 左边真实，右边预测
        x, z = data
        y = model(x.to(device))
        img_y = torch.squeeze(y.cpu()).numpy()
        plt.subplot(1, 2, 1)
        z = torch.squeeze(z).numpy()
        plt.imshow(z)
        plt.axis('on')
        plt.subplot(1, 2, 2)
        plt.imshow(img_y)
        plt.axis('on')
        plt.pause(0.01)
        filename = '../../dataset/u_net_data/liver/predict/' + 'new_%d.png'%i
        Image.fromarray((img_y*255).astype('uint8')).convert('L').save(filename)