In [31]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [32]:
class DoubleConv(nn.Module):
    """Two stacked (3x3 convolution -> BatchNorm -> ReLU) stages.

    Both convolutions use ``padding=1``, so height and width are preserved,
    and neither has a bias term (each is followed directly by BatchNorm).

    Args:
        in_channels: Channels of the input feature map.
        out_channels: Channels produced by the second stage.
        mid_channels: Channels produced by the first stage; a falsy value
            (``None`` or ``0``) means "use ``out_channels``".
    """

    def __init__(self, in_channels, out_channels, mid_channels=None):
        super().__init__()
        hidden = mid_channels or out_channels

        # Assemble the two identical stages in order so the Sequential keeps
        # the layout conv, bn, relu, conv, bn, relu (indices 0..5).
        stages = []
        for c_in, c_out in ((in_channels, hidden), (hidden, out_channels)):
            stages.append(nn.Conv2d(c_in, c_out, kernel_size=3, padding=1, bias=False))
            stages.append(nn.BatchNorm2d(c_out))
            stages.append(nn.ReLU(inplace=True))
        self.double_conv = nn.Sequential(*stages)

    def forward(self, x):
        """Run ``x`` through both stages and return the resulting feature map."""
        features = self.double_conv(x)
        return features
    
# Smoke test: DoubleConv should change only the channel count (3 -> 32).
inputs = torch.rand(8, 3, 224, 224)
conv = DoubleConv(in_channels=3, out_channels=32)
outputs = conv(inputs)
print(
    f'inputs.shape  : {inputs.shape}',
    f'outputs.shape : {outputs.shape}',
    sep='\n',
)

inputs.shape  : torch.Size([8, 3, 224, 224])
outputs.shape : torch.Size([8, 32, 224, 224])


In [33]:
class Down(nn.Module):
    """Encoder stage: 2x2 max-pooling followed by a ``DoubleConv``.

    The pooling halves height and width; the double convolution then maps
    ``in_channels`` to ``out_channels``.

    Args:
        in_channels: Channels of the incoming feature map.
        out_channels: Channels of the returned feature map.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        pool = nn.MaxPool2d(2)
        conv = DoubleConv(in_channels, out_channels)
        self.maxpool_conv = nn.Sequential(pool, conv)

    def forward(self, x):
        """Downsample ``x`` and return the convolved feature map."""
        pooled_and_convolved = self.maxpool_conv(x)
        return pooled_and_convolved

# Smoke test: Down should halve H and W and map 16 -> 32 channels.
inputs = torch.rand(8, 16, 224, 224)
down = Down(in_channels=16, out_channels=32)
outputs = down(inputs)
print(
    f'inputs.shape  : {inputs.shape}',
    f'outputs.shape : {outputs.shape}',
    sep='\n',
)

inputs.shape  : torch.Size([8, 16, 224, 224])
outputs.shape : torch.Size([8, 32, 112, 112])


In [34]:
class Up(nn.Module):
    """Decoder stage: concatenate two feature maps, upsample 2x, then ``DoubleConv``.

    The inputs are joined along the channel axis first. A stride-2 transposed
    convolution then doubles height and width while keeping the channel count,
    and the ``DoubleConv`` reduces the result to ``out_channels``.

    Args:
        in_channels: Channel count of the concatenated tensor, i.e. the
            channels of ``x1`` plus the channels of ``x2``.
        out_channels: Channels of the returned feature map.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        # Transposed convolution: kernel_size and stride are expressed in terms
        # of the forward convolution it inverts, so stride=2 doubles H and W.
        self.up = nn.ConvTranspose2d(in_channels, in_channels, kernel_size=2, stride=2)
        self.conv = DoubleConv(in_channels, out_channels)

    def forward(self, x1, x2):
        """Fuse ``x1`` with ``x2`` and return the upsampled, convolved result.

        Args:
            x1: Feature map coming from the previous (deeper) stage.
            x2: Skip-connection feature map; it defines the spatial size used
                for the concatenation.

        Returns:
            Tensor with ``out_channels`` channels and twice the height and
            width of ``x2``.
        """
        # Pooling floors odd sizes, so x1 can be smaller than its skip
        # connection. Zero-pad x1 (split evenly between both sides) so the
        # concatenation below does not fail; equal sizes take the old path.
        diff_h = x2.size(-2) - x1.size(-2)
        diff_w = x2.size(-1) - x1.size(-1)
        if diff_h != 0 or diff_w != 0:
            x1 = F.pad(
                x1,
                [diff_w // 2, diff_w - diff_w // 2, diff_h // 2, diff_h - diff_h // 2],
            )

        x = torch.cat([x1, x2], dim=1)
        x = self.up(x)  # doubles H and W; the channel count is unchanged here
        x = self.conv(x)  # the channel reduction to out_channels happens here

        return x


# Smoke test: two 16-channel maps are concatenated (32 channels), upsampled
# 2x and reduced to 8 channels.
inputs = torch.rand(8, 16, 224, 224)
up = Up(in_channels=16*2, out_channels=16//2)
outputs = up(inputs, inputs)
print(
    f'inputs.shape  : {inputs.shape}',
    f'outputs.shape : {outputs.shape}',
    sep='\n',
)

inputs.shape  : torch.Size([8, 16, 224, 224])
outputs.shape : torch.Size([8, 8, 448, 448])


In [35]:
class OutConv(nn.Module):
    """Output head: a single 1x1 convolution (with bias) that remaps channels.

    Args:
        in_channels: Channels of the incoming feature map.
        out_channels: Channels of the returned map.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)

    def forward(self, x):
        """Apply the 1x1 convolution; height and width are unchanged."""
        projected = self.conv(x)
        return projected

# Smoke test: OutConv should change only the channel count (16 -> 32).
inputs = torch.rand(8, 16, 224, 224)
out = OutConv(in_channels=16, out_channels=32)
outputs = out(inputs)
print(
    f'inputs.shape  : {inputs.shape}',
    f'outputs.shape : {outputs.shape}',
    sep='\n',
)

inputs.shape  : torch.Size([8, 16, 224, 224])
outputs.shape : torch.Size([8, 32, 224, 224])


## UNet Model

In [36]:
class UNet(nn.Module):
    """U-Net style encoder/decoder for per-pixel classification.

    Layout:
        * ``inc``: ``DoubleConv`` from ``in_channels`` to 32 channels.
        * ``down1``..``down4``: each halves the resolution and doubles the
          channels (32 -> 64 -> 128 -> 256 -> 512).
        * ``bridge``: ``DoubleConv`` at the lowest resolution (512 -> 512).
        * ``up4``..``up1``: each concatenates the running decoder map with the
          encoder map of the same resolution, upsamples 2x and halves the
          channels (down to 32).
        * ``outc``: 1x1 convolution producing one logit per class.

    Args:
        in_channels: Channels of the input image.
        out_channels: Number of classes (channels of the returned logits).

    Note:
        The resolution is halved four times and doubled four times, so input
        heights and widths that are multiples of 16 are the safe choice.
    """

    def __init__(self, in_channels=3, out_channels=2):
        super(UNet, self).__init__()

        self.inc = DoubleConv(in_channels=in_channels, out_channels=32)
        self.down1 = Down(in_channels=32, out_channels=64)
        self.down2 = Down(in_channels=64, out_channels=128)
        self.down3 = Down(in_channels=128, out_channels=256)
        self.down4 = Down(in_channels=256, out_channels=512)

        self.bridge = DoubleConv(in_channels=512, out_channels=512)

        # Each Up receives the concatenation of two equally wide maps,
        # hence the "* 2" on its input channel count.
        self.up4 = Up(in_channels=512 * 2, out_channels=256)
        self.up3 = Up(in_channels=256 * 2, out_channels=128)
        self.up2 = Up(in_channels=128 * 2, out_channels=64)
        self.up1 = Up(in_channels=64 * 2, out_channels=32)

        # The head must emit unconstrained logits. A DoubleConv here would end
        # in BatchNorm + ReLU and force every class score to be non-negative,
        # so the plain 1x1 convolution is used instead.
        self.outc = OutConv(in_channels=32, out_channels=out_channels)

    def forward(self, x):
        """Return class logits for a batch of images.

        Args:
            x: Batch of shape ``(batch, in_channels, height, width)``.

        Returns:
            Logits of shape ``(batch, out_channels, height, width)``.
        """
        x = self.inc(x)

        d1 = self.down1(x)
        d2 = self.down2(d1)
        d3 = self.down3(d2)
        d4 = self.down4(d3)

        bridge = self.bridge(d4)

        u4 = self.up4(bridge, d4)
        u3 = self.up3(u4, d3)
        u2 = self.up2(u3, d2)
        u1 = self.up1(u2, d1)

        out = self.outc(u1)

        return out

    def predict(self, inputs):
        """Predict the class mask of a single sample.

        The forward pass runs in evaluation mode with gradient tracking
        disabled, so BatchNorm uses its running statistics and no autograd
        graph is built. The module's previous training/eval mode is restored
        before returning.

        :param inputs: (input_channels, height, width)
        :return: mask_pred, an integer tensor of shape (height, width) holding
            the index of the highest-scoring class for every pixel
        """
        was_training = self.training
        self.eval()
        try:
            with torch.no_grad():
                # Add a batch axis for the forward pass, then drop it again.
                logits = self(inputs.unsqueeze(0)).squeeze(0)
        finally:
            self.train(was_training)

        # Reduce over the class axis only; a second squeeze would remove that
        # axis whenever the model has a single output channel.
        return torch.argmax(logits, dim=0)

In [37]:
# Smoke test: the network should return one logit map per class at the
# resolution of the input batch.
inputs = torch.rand(8, 3, 224, 224)
net = UNet(in_channels=3, out_channels=2)
outputs = net(inputs)

print(
    f'inputs.shape  : {inputs.shape}',
    f'outputs.shape : {outputs.shape}',
    sep='\n',
)

inputs.shape  : torch.Size([8, 3, 224, 224])
outputs.shape : torch.Size([8, 2, 224, 224])


## DataLoader

In [38]:
import os
from PIL import Image
from torch import nn
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from torchvision.transforms import ToTensor


class BasicDataset(Dataset):
    """Image / mask pairs read from ``<data_dir>/imgs`` and ``<data_dir>/masks``.

    Every ``<stem>.jpg`` in ``imgs`` is paired with ``<stem>_mask.gif`` in
    ``masks``. Both are resized to ``img_size`` and converted with
    ``ToTensor``.

    Args:
        data_dir: Root directory containing the ``imgs`` and ``masks`` folders.
        img_size: ``(width, height)`` passed to ``PIL.Image.resize``.
            Defaults to ``(480, 320)``.
    """

    def __init__(self, data_dir, img_size=(480, 320)):
        super().__init__()
        self.data_dir = data_dir
        self.img_size = tuple(img_size)
        self.files = self._list_stems(os.path.join(self.data_dir, 'imgs'))
        self.convert = ToTensor()

    @staticmethod
    def _list_stems(img_dir):
        """Return the sorted file stems of the ``.jpg`` images in ``img_dir``.

        Hidden entries (names starting with ``.``, e.g. ``.DS_Store`` or
        ``._x.jpg``) and files with another extension are skipped. Only the
        trailing ``.jpg`` is removed, so stems that contain dots stay intact.
        Sorting makes the sample order independent of ``os.listdir``.
        """
        suffix = '.jpg'
        stems = [
            name[:-len(suffix)]
            for name in os.listdir(img_dir)
            if name.endswith(suffix) and not name.startswith('.')
        ]
        return sorted(stems)

    def __len__(self):
        """Number of image / mask pairs."""
        return len(self.files)

    def __getitem__(self, i):
        """Return the ``(img, mask)`` tensors of sample ``i``.

        ``img`` is the ``ToTensor`` output for the picture. ``mask`` is the
        ``ToTensor`` output for the mask with a size-1 leading (channel) axis
        removed.
        """
        img = self._load_img(self.files[i], is_mask=False)
        mask = self._load_img(self.files[i], is_mask=True)

        img = self.convert(img)
        # NOTE(review): ToTensor rescales 8-bit pixel data to [0, 1]; confirm
        # the masks still hold usable class ids once the training code casts
        # them with .long().
        mask = self.convert(mask)

        # Drop the mask's channel dimension
        mask = torch.squeeze(mask, dim=0)

        return img, mask

    def _load_img(self, file_name, is_mask=False):
        """Open the image or mask for ``file_name`` and return it resized.

        Args:
            file_name: File stem shared by the image and its mask.
            is_mask: Load ``masks/<stem>_mask.gif`` when true, otherwise
                ``imgs/<stem>.jpg``.

        Returns:
            A resized ``PIL.Image`` of size ``self.img_size``.
        """
        if is_mask:
            path = os.path.join(self.data_dir, 'masks', file_name + '_mask.gif')
        else:
            path = os.path.join(self.data_dir, 'imgs', file_name + '.jpg')

        # resize() returns a new, fully loaded image, so the source file can
        # be closed as soon as it has been produced.
        with Image.open(path) as obj:
            return obj.resize(size=self.img_size)

In [39]:
data_dir = '/Volumes/SSD/SSD/blueberry/datasets/UNet'  # your data path

# Build the dataset object
dataset = BasicDataset(data_dir=data_dir)
# Build the data loader. Note on multi-process DataLoader errors: worker
# processes need this code to live in a .py file behind an
# `if __name__ == '__main__':` guard.
# Pinned host memory only helps host-to-GPU copies, so request it only when
# CUDA is available; asking for it on a CPU-only machine just emits a warning.
dataloader = DataLoader(
    dataset,
    batch_size=1,
    shuffle=True,
    pin_memory=torch.cuda.is_available(),
)

In [40]:
# Inspect the tensor shapes of the first batch only.
for images, masks in dataloader:
    print(
        f'images.shape : {images.shape}',
        f'masks.shape  : {masks.shape}',
        sep='\n',
    )
    break

images.shape : torch.Size([1, 3, 320, 480])
masks.shape  : torch.Size([1, 320, 480])


## 损失函数

In [41]:
class DiceLoss(nn.Module):
    """Multi-class soft Dice loss, averaged over the classes.

    For every class the loss is ``1 - (2*sum(p*t) + eps) / (sum(p*p) +
    sum(t*t) + eps)``, where ``p`` is the predicted score map and ``t`` the
    one-hot target map of that class, summed over the whole batch.
    """

    def __init__(self):
        super(DiceLoss, self).__init__()
        # Taken from the channel dimension of the prediction on every forward call.
        self.n_classes = None

    def _one_hot_encoder(self, input_tensor):
        """Expand integer labels ``(N, ...)`` to float one-hot ``(N, n_classes, ...)``.

        Labels outside ``range(n_classes)`` produce an all-zero vector.
        """
        class_masks = [input_tensor == i for i in range(self.n_classes)]
        return torch.stack(class_masks, dim=1).float()

    def _dice_loss(self, score, target):
        """Return ``1 - dice`` for one class, computed over all elements."""
        target = target.float()
        smooth = 1e-5  # keeps the ratio defined when a class is absent
        intersect = torch.sum(score * target)
        y_sum = torch.sum(target * target)
        z_sum = torch.sum(score * score)
        return 1 - (2 * intersect + smooth) / (z_sum + y_sum + smooth)

    def forward(self, inputs, target, weight=None, softmax=False):
        """Compute the class-averaged Dice loss.

        Args:
            inputs: Scores of shape ``(N, C, ...)``; raw logits if
                ``softmax`` is true, otherwise already normalised scores.
            target: Integer class labels of shape ``(N, ...)``.
            weight: Optional per-class weights (indexable, length ``C``);
                defaults to 1 for every class.
            softmax: Apply ``softmax`` over the class axis to ``inputs`` first.

        Returns:
            Scalar tensor: the weighted sum of per-class losses divided by ``C``.

        Raises:
            AssertionError: If the one-hot target shape differs from ``inputs``.
        """
        self.n_classes = inputs.size(1)

        if softmax:
            inputs = torch.softmax(inputs, dim=1)
        target = self._one_hot_encoder(target)
        if weight is None:
            weight = [1] * self.n_classes
        assert inputs.size() == target.size(), 'predict {} & target {} shape do not match'.format(inputs.size(), target.size())

        # Accumulate on-device only: no per-class .item() call, which would
        # force a device-to-host synchronisation on every class of every step.
        loss = 0.0
        for i in range(self.n_classes):
            loss = loss + self._dice_loss(inputs[:, i], target[:, i]) * weight[i]
        return loss / self.n_classes

In [42]:
class CrossEntropyDiceLoss(nn.CrossEntropyLoss):
    """Cross-entropy loss plus the softmax Dice loss of the same prediction.

    Constructor arguments are forwarded unchanged to ``nn.CrossEntropyLoss``.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.dice_loss = DiceLoss()

    def forward(self, pred, label):
        """Return ``cross_entropy(pred, label) + dice(softmax(pred), label)``.

        ``label`` is cast to ``long`` before either term is computed.
        """
        target = label.long()
        ce_term = super().forward(pred, target)
        dice_term = self.dice_loss(pred, target, softmax=True)
        return ce_term + dice_term

In [43]:
# Try the loss function on random scores against an all-ones label map
loss_fn = nn.CrossEntropyLoss()

pred = torch.rand(8, 2, 224, 224, dtype=torch.float)
label = torch.ones(8, 224, 224, dtype=torch.long)
loss_fn(pred, label)

tensor(0.7135)

## 定义优化器

In [44]:
from torch import optim


# SGD with momentum and weight decay over every parameter of the network.
optimizer = optim.SGD(
    net.parameters(),
    lr=0.05,
    momentum=0.9,
    weight_decay=0.0001,
)

## 训练的逻辑

In [45]:
epochs = 2  # train for 2 passes over the dataset
iter_num = 1
device = torch.device('cuda')  # GPU device; NOTE(review): only referenced by the commented-out lines below

net.train()  # put the network in training mode
# net.to(device)  # move the network to GPU memory

# Unless both the network and the data are moved to GPU memory, training runs on the CPU.
for epoch in range(1, epochs+1):
    for inputs, labels in dataloader:
        # Move the batch to GPU memory
        # inputs, labels = inputs.to(device), labels.to(device)
        
        outputs = net(inputs)
        loss = loss_fn(outputs, labels.long())

        # Clear the gradients left over from the previous step
        optimizer.zero_grad()

        # Back-propagate to compute the gradients
        loss.backward()

        # Apply the parameter update
        optimizer.step()

        print(f'epoch {epoch}, iter_num {iter_num}, loss {loss}')
        iter_num += 1

epoch 1, iter_num 1, loss 0.7786043882369995
epoch 1, iter_num 2, loss 0.732941746711731
epoch 1, iter_num 3, loss 0.6942446231842041
epoch 1, iter_num 4, loss 0.6103910803794861
epoch 1, iter_num 5, loss 0.5705145001411438
epoch 1, iter_num 6, loss 0.5364541411399841
epoch 1, iter_num 7, loss 0.5060492157936096
epoch 1, iter_num 8, loss 0.48114022612571716
epoch 1, iter_num 9, loss 0.4522033631801605
epoch 1, iter_num 10, loss 0.44052189588546753
epoch 1, iter_num 11, loss 0.4210132360458374
epoch 1, iter_num 12, loss 0.4077134132385254
epoch 1, iter_num 13, loss 0.3952310383319855
epoch 1, iter_num 14, loss 0.38453152775764465
epoch 1, iter_num 15, loss 0.37258830666542053
epoch 1, iter_num 16, loss 0.3610352575778961
epoch 1, iter_num 17, loss 0.34523364901542664
epoch 1, iter_num 18, loss 0.3330628275871277
epoch 1, iter_num 19, loss 0.31670883297920227
epoch 1, iter_num 20, loss 0.315682053565979
epoch 1, iter_num 21, loss 0.311465859413147
epoch 1, iter_num 22, loss 0.29090732336

KeyboardInterrupt: 