## 训练U-Net模型

In [2]:
import json
import os

import numpy as np
import torch
import torch.optim as optim
from torch.utils.data import DataLoader

from dataset import BrainSegmentationDataset as Dataset
from logger import Logger
from loss import DiceLoss
from transform import transforms
from unet import UNet
from resunet import ResUNet
from utils import log_images, dsc

In [3]:
# Build the data loaders
def data_loaders(args):
    """Create the training and validation ``DataLoader`` objects.

    Args:
        args: settings object; ``batch_size`` and ``workers`` are read here,
            and the object is forwarded unchanged to ``datasets``.

    Returns:
        A ``(loader_train, loader_valid)`` tuple. The training loader shuffles
        and drops the last incomplete batch; the validation loader keeps
        every sample and uses the default (unshuffled) order.
    """

    def worker_init(worker_id):
        # Seed NumPy in each worker with a fixed, worker-specific value.
        np.random.seed(42 + worker_id)

    train_set, valid_set = datasets(args)

    # Options that are identical for both loaders.
    shared = dict(
        batch_size=args.batch_size,
        num_workers=args.workers,
        worker_init_fn=worker_init,
    )

    loader_train = DataLoader(train_set, shuffle=True, drop_last=True, **shared)
    loader_valid = DataLoader(valid_set, drop_last=False, **shared)
    return loader_train, loader_valid

# Dataset construction
def datasets(args):
    """Build the training and validation datasets.

    Args:
        args: settings object; reads ``images``, ``image_size``,
            ``aug_scale`` and ``aug_angle``.

    Returns:
        A ``(train, valid)`` tuple of ``Dataset`` instances. Only the training
        set receives an augmentation transform; the validation set is created
        with ``random_sampling=False``.
    """
    # Arguments common to both subsets.
    common = {"images_dir": args.images, "image_size": args.image_size}

    augmentation = transforms(
        scale=args.aug_scale,
        angle=args.aug_angle,
        flip_prob=0.5,
    )

    train = Dataset(subset="train", transform=augmentation, **common)
    valid = Dataset(subset="validation", random_sampling=False, **common)
    return train, valid

# Per-volume metric computation
def dsc_per_volume(validation_pred, validation_true, patient_slice_index):
    """Compute one ``dsc`` score per patient volume.

    Args:
        validation_pred: sequence of per-slice predictions, in the same order
            as ``patient_slice_index``.
        validation_true: sequence of per-slice ground-truth masks, same order.
        patient_slice_index: sequence whose items carry the patient index in
            position 0; used to count how many slices each patient has.

    Returns:
        A list with one ``dsc`` value per patient index, obtained by slicing
        consecutive runs of ``validation_pred`` / ``validation_true``.
    """
    slices_per_patient = np.bincount([entry[0] for entry in patient_slice_index])

    scores = []
    start = 0
    for count in slices_per_patient:
        stop = start + count
        volume_pred = np.array(validation_pred[start:stop])
        volume_true = np.array(validation_true[start:stop])
        scores.append(dsc(volume_pred, volume_true))
        start = stop
    return scores


def log_loss_summary(logger, loss, step, prefix=""):
    """Log the mean of ``loss`` as a scalar tagged ``<prefix>loss``.

    Args:
        logger: object exposing ``scalar_summary(tag, value, step)``.
        loss: collection of loss values to average with ``np.mean``.
        step: step number passed through to the logger.
        prefix: string prepended to the ``"loss"`` tag.
    """
    tag = prefix + "loss"
    mean_loss = np.mean(loss)
    logger.scalar_summary(tag, mean_loss, step)


def makedirs(args):
    """Create the weights and logs directories if they do not exist yet.

    Args:
        args: settings object; reads ``weights`` and ``logs``.
    """
    for attribute in ("weights", "logs"):
        os.makedirs(getattr(args, attribute), exist_ok=True)


def snapshotargs(args):
    """Write the run settings to ``<args.logs>/args.json``.

    ``vars(args)`` only returns attributes stored on the instance, so a
    settings object whose defaults live on the class (as ``Args`` in this file
    does) would be saved as ``{}`` or with only the overridden keys. To record
    the full configuration, every public, non-callable attribute reachable
    through ``dir`` is collected first and the instance attributes are then
    layered on top.

    Args:
        args: settings object; ``logs`` is the directory that receives
            ``args.json``. All collected values must be JSON-serialisable.
    """
    settings = {}
    for name in dir(args):
        if name.startswith("_"):
            continue
        value = getattr(args, name)
        if not callable(value):
            settings[name] = value
    # Instance attributes take precedence and are kept exactly as vars() sees them.
    settings.update(vars(args))

    args_file = os.path.join(args.logs, "args.json")
    with open(args_file, "w") as fp:
        json.dump(settings, fp)

class Args:
    """Experiment settings, stored as class-level defaults."""

    # --- runtime ---
    device = "cuda:0"  # device identifier used when CUDA is available
    workers = 8  # passed to the DataLoaders as num_workers

    # --- optimisation ---
    epochs = 100  # number of training epochs
    batch_size = 16  # samples per batch
    lr = 1e-4  # learning rate

    # --- data and augmentation ---
    images = "/home/mw/input/kaggle_3m4773/archive/archive/kaggle_3m"  # dataset root
    image_size = 256  # image size handed to the dataset
    aug_scale = 0.05  # scale augmentation
    aug_angle = 15  # rotation augmentation

    # --- outputs and visualisation ---
    weights = "./weights"  # directory for saved model parameters
    logs = "./logs"  # directory for logs and args.json
    vis_images = 200  # number of predictions to visualise
    vis_freq = 10  # epochs between two visualisations
# Instantiate the settings and build both loaders once; train() reads
# loader_train / loader_valid as module-level names.
args = Args()
loader_train, loader_valid = data_loaders(args)

reading train images...
preprocessing train volumes...
cropping train volumes...
padding train volumes...
resizing train volumes...
normalizing train volumes...
done creating train dataset
reading validation images...
preprocessing validation volumes...
cropping validation volumes...
padding validation volumes...
resizing validation volumes...
normalizing validation volumes...
done creating validation dataset


In [4]:
def train(args, optimizer, model):
    """Train ``model`` with a Dice loss and keep the best-scoring weights.

    Every epoch runs a "train" phase followed by a "valid" phase over the
    module-level ``loader_train`` / ``loader_valid`` loaders. After each
    validation phase the mean per-volume DSC is computed, and whenever it
    exceeds the best value so far the model's ``state_dict`` is saved to
    ``<args.weights>/unet.pt``.

    Args:
        args: settings object; reads ``weights``, ``logs``, ``device``,
            ``epochs``, ``batch_size``, ``vis_freq`` and ``vis_images``.
        optimizer: optimizer used for ``zero_grad`` / ``step``; expected to
            have been built from ``model``'s parameters by the caller.
        model: network to train; it is moved to the selected device.

    Returns:
        None. Progress is printed and written through ``Logger``.
    """
    makedirs(args)
    snapshotargs(args)
    # Fall back to the CPU when CUDA is unavailable, whatever args.device says.
    device = torch.device(args.device if torch.cuda.is_available() else "cpu")

    loaders = {"train": loader_train, "valid": loader_valid}
    model.to(device)

    dsc_loss = DiceLoss()
    best_validation_dsc = 0.0
    logger = Logger(args.logs)
    loss_train = []
    loss_valid = []

    # Counts training batches only; also used as the x-axis for all logging.
    step = 0
    for epoch in range(args.epochs):
        for phase in ["train", "valid"]:
            if phase == "train":
                model.train()
            else:
                model.eval()

            # Per-slice outputs collected during the validation phase.
            validation_pred = []
            validation_true = []

            for i, data in enumerate(loaders[phase]):
                if phase == "train":
                    step += 1

                x, y_true = data
                x, y_true = x.to(device), y_true.to(device)

                optimizer.zero_grad()

                # Gradients are only tracked while training.
                with torch.set_grad_enabled(phase == "train"):
                    y_pred = model(x)
                    loss = dsc_loss(y_pred, y_true)

                    if phase == "valid":
                        loss_valid.append(loss.item())
                        # Iterating an array yields its items along axis 0,
                        # i.e. one entry per sample in the batch.
                        y_pred_np = y_pred.detach().cpu().numpy()
                        validation_pred.extend(y_pred_np)
                        y_true_np = y_true.detach().cpu().numpy()
                        validation_true.extend(y_true_np)
                        # Log example images every vis_freq epochs and on the
                        # last epoch, capped at vis_images images in total.
                        if (epoch % args.vis_freq == 0) or (epoch == args.epochs - 1):
                            if i * args.batch_size < args.vis_images:
                                tag = "image/{}".format(i)
                                num_images = args.vis_images - i * args.batch_size
                                logger.image_list_summary(
                                    tag,
                                    log_images(x, y_true, y_pred)[:num_images],
                                    step,
                                )

                    if phase == "train":
                        loss_train.append(loss.item())
                        loss.backward()
                        optimizer.step()

                # Flush the running training loss whenever step + 1 is a
                # multiple of 10.
                if phase == "train" and (step + 1) % 10 == 0:
                    log_loss_summary(logger, loss_train, step)
                    loss_train = []

            if phase == "valid":
                log_loss_summary(logger, loss_valid, step, prefix="val_")
                print("epoch {} | val_loss: {}".format(epoch + 1, np.mean(loss_valid)))
                mean_dsc = np.mean(
                    dsc_per_volume(
                        validation_pred,
                        validation_true,
                        loader_valid.dataset.patient_slice_index,
                    )
                )
                logger.scalar_summary("val_dsc", mean_dsc, step)
                print("epoch {} | val_dsc: {}".format(epoch + 1, mean_dsc))
                # Checkpoint only when the validation score improves.
                if mean_dsc > best_validation_dsc:
                    best_validation_dsc = mean_dsc
                    torch.save(model.state_dict(), os.path.join(args.weights, "unet.pt"))
                loss_valid = []

    # "{:.4f}" gives four decimals; the previous "{:4f}" only set a minimum
    # field width and printed the default six.
    print("Best validation mean DSC: {:.4f}".format(best_validation_dsc))

### 输入参数  
  
device: 设备编号  
batch_size: 批大小  
epochs: 训练轮数  
lr: 学习率  
workers: 数据加载(DataLoader)使用的进程数  
vis_images: 可视化预测结果的数目 (在tensorboard中查看)  
vis_freq: 两次可视化预测结果的间隔  
weights: 训练后的模型参数路径  
logs: 日志及 args.json 的保存路径  
images: 数据集路径  
image_size: 图像尺寸  
aug_scale: 数据增强(放缩)  
aug_angle: 数据增强(旋转)  

unet 模型

In [5]:
# Baseline run: U-Net optimised with Adam at the configured learning rate.
unet = UNet(
    in_channels=Dataset.in_channels,
    out_channels=Dataset.out_channels,
)
optimizer = optim.Adam(unet.parameters(), lr=args.lr)
train(args, optimizer, unet)

epoch 1 | val_loss: 0.921807274931953
epoch 1 | val_dsc: 0.3258537218822525
epoch 2 | val_loss: 0.9002781311670939
epoch 2 | val_dsc: 0.6038292189175508
epoch 3 | val_loss: 0.8759276299249559
epoch 3 | val_dsc: 0.5678692500837503
epoch 4 | val_loss: 0.8460489511489868
epoch 4 | val_dsc: 0.744668116260923
epoch 5 | val_loss: 0.8028065789313543
epoch 5 | val_dsc: 0.7543791167003567
epoch 6 | val_loss: 0.7496414525168282
epoch 6 | val_dsc: 0.8011146934581495
epoch 7 | val_loss: 0.6965673622630891
epoch 7 | val_dsc: 0.8019404806791135
epoch 8 | val_loss: 0.652526353086744
epoch 8 | val_dsc: 0.7944773105237847
epoch 9 | val_loss: 0.5962302571251279
epoch 9 | val_dsc: 0.8110659103470164
epoch 10 | val_loss: 0.5769200523694357
epoch 10 | val_dsc: 0.7149988022182465
epoch 11 | val_loss: 0.5289075488135928
epoch 11 | val_dsc: 0.8158750934072589
epoch 12 | val_loss: 0.5089089870452881
epoch 12 | val_dsc: 0.8126189484555605
epoch 13 | val_loss: 0.48591749157224384
epoch 13 | val_dsc: 0.8126269273

优化器调整

In [6]:
# Optimizer experiment: a fresh U-Net trained with SGD (momentum and weight
# decay) at a higher learning rate, writing to its own output directories.
args.lr = 0.001
args.weights = "./sgdweights"
args.logs = "./sgdlogs"

unet = UNet(
    in_channels=Dataset.in_channels,
    out_channels=Dataset.out_channels,
)
optimizer = optim.SGD(
    unet.parameters(),
    lr=args.lr,
    momentum=0.9,
    weight_decay=0.0001,
)
train(args, optimizer, unet)

epoch 1 | val_loss: 0.9439293770563035
epoch 1 | val_dsc: 0.08842493219442307
epoch 2 | val_loss: 0.9249004835174197
epoch 2 | val_dsc: 0.1938810667292812
epoch 3 | val_loss: 0.8835468717983791
epoch 3 | val_dsc: 0.5082186329976641
epoch 4 | val_loss: 0.7628934894289289
epoch 4 | val_dsc: 0.7627984203287029
epoch 5 | val_loss: 0.6226285667646498
epoch 5 | val_dsc: 0.7498127354644857
epoch 6 | val_loss: 0.5618038574854533
epoch 6 | val_dsc: 0.753520729063664
epoch 7 | val_loss: 0.5558435689835322
epoch 7 | val_dsc: 0.5872649645781959
epoch 8 | val_loss: 0.5662362518764678
epoch 8 | val_dsc: 0.5797809237865156
epoch 9 | val_loss: 0.5468812442961193
epoch 9 | val_dsc: 0.5735789605373853
epoch 10 | val_loss: 0.4892008219446455
epoch 10 | val_dsc: 0.7944753810952956
epoch 11 | val_loss: 0.4409557126817249
epoch 11 | val_dsc: 0.7897980512396903
epoch 12 | val_loss: 0.46258695068813505
epoch 12 | val_dsc: 0.789408170274082
epoch 13 | val_loss: 0.42238350425447735
epoch 13 | val_dsc: 0.8001476

模型调整


In [7]:
# Model experiment: ResUNet trained with Adam, writing to its own directories.
# NOTE(review): args.lr is not set in this cell, so it keeps whatever value an
# earlier cell left on args (0.001 if the SGD cell ran before this one) rather
# than the Args default - confirm that this is the intended learning rate.
args.weights = "./resweights"
args.logs = "./reslogs"

resunet = ResUNet(
    in_channels=Dataset.in_channels,
    out_channels=Dataset.out_channels,
)
optimizer = optim.Adam(resunet.parameters(), lr=args.lr)
train(args, optimizer, resunet)

epoch 1 | val_loss: 0.6717041361899603
epoch 1 | val_dsc: 0.6245335025980845
epoch 2 | val_loss: 0.48789879537764047
epoch 2 | val_dsc: 0.7863545327105971
epoch 3 | val_loss: 0.44301135483242216
epoch 3 | val_dsc: 0.7920629108823141
epoch 4 | val_loss: 0.48925745771044776
epoch 4 | val_dsc: 0.7487530167548705
epoch 5 | val_loss: 0.4268521098863511
epoch 5 | val_dsc: 0.7942466841883222
epoch 6 | val_loss: 0.40247605244318646
epoch 6 | val_dsc: 0.8021607210232679
epoch 7 | val_loss: 0.37926592997142244
epoch 7 | val_dsc: 0.8040793640590831
epoch 8 | val_loss: 0.35958669299171087
epoch 8 | val_dsc: 0.8079866484460899
epoch 9 | val_loss: 0.394537695816585
epoch 9 | val_dsc: 0.80612418532658
epoch 10 | val_loss: 0.37678466240564984
epoch 10 | val_dsc: 0.8105120875750655
epoch 11 | val_loss: 0.3495638597579229
epoch 11 | val_dsc: 0.8110022855989858
epoch 12 | val_loss: 0.36205443881806876
epoch 12 | val_dsc: 0.8102956621539373
epoch 13 | val_loss: 0.34891607080187115
epoch 13 | val_dsc: 0.80