In [1]:
%cd /home/mw/project/brain-seg/

/home/mw/project/brain-seg


## 训练U-Net模型

新建python 3.10环境（以conda为例）  

```
conda create -n hw4 python=3.10 -y  
conda activate hw4  
```

安装torch，注意cuda版本适配  
```
pip install torch==2.0.* torchvision==0.15.* --index-url https://download.pytorch.org/whl/cu117  
```

安装其他依赖库  
```
pip install ipykernel==6.26.* matplotlib==3.8.* medpy==0.4.* scipy==1.11.* numpy==1.23.* scikit-image==0.22.* imageio==2.31.* tensorboard==2.15.* tqdm==4.* -i https://pypi.tuna.tsinghua.edu.cn/simple  
```

In [29]:
# !pip install ipykernel==6.26.* matplotlib==3.8.* medpy==0.4.*  numpy==1.23.* scikit-image==0.22.* imageio==2.31.* tensorboard==2.15.* tqdm==4.* -i https://pypi.doubanio.com/simple/
# !pip install tensorflow -i https://pypi.doubanio.com/simple/

In [6]:
import json
import os

from types import SimpleNamespace
import tqdm
import numpy as np
import torch
import torch.optim as optim
from torch.utils.data import DataLoader

from dataset import BrainSegmentationDataset as Dataset
from logger import Logger
from loss import DiceLoss
from transform import transforms
from unet import UNet
from utils import log_images, dsc

### 输入参数  
  
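device: 设备编号 (例如 `cuda:0`；CUDA 不可用时自动回退到 CPU)  
batch_size: 批大小  
epochs: 训练轮数  
lr: 学习率  
workers: DataLoader 加载数据使用的子进程数  
vis_images: 可视化预测结果的数目 (在tensorboard中查看)  
vis_freq: 两次可视化预测结果之间间隔的训练轮数 (epoch)  
weights: 训练后模型参数的保存目录  
logs: 日志与参数快照 (args.json) 的保存目录  
images: 数据集路径   
image_size: 图像尺寸   
aug_scale: 数据增强(放缩)  
aug_angle: 数据增强(旋转)  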

In [30]:
# Run configuration shared by the cells below.
args = SimpleNamespace(
    # --- hardware / optimisation ---
    device="cuda:0",   # later cells fall back to CPU when CUDA is unavailable
    batch_size=16,
    epochs=1,
    lr=1e-4,
    workers=0,         # forwarded to DataLoader as num_workers
    # --- visualisation of predictions ---
    vis_images=200,
    vis_freq=10,
    # --- paths ---
    weights="./weights",
    logs="./logs",
    images="/home/mw/input/brain_seg1509/archive/kaggle_3m",
    # --- image size / augmentation ---
    image_size=256,
    aug_scale=0.05,
    aug_angle=15,
)

In [8]:
# 读取数据
def worker_init(worker_id):
    """Seed NumPy's global RNG for a DataLoader worker.

    The seed is ``42 + worker_id``, so every worker id maps to its own
    fixed seed.
    """
    base_seed = 42
    np.random.seed(base_seed + worker_id)

def data_loaders(args):
    """Create the train/validation datasets and a DataLoader for each.

    Args:
        args: configuration object; reads ``batch_size`` and ``workers``
            here and is forwarded unchanged to ``datasets``.

    Returns:
        Tuple ``(dataset_train, dataset_valid, loader_train, loader_valid)``.
    """
    train_set, valid_set = datasets(args)

    # Settings common to both loaders.
    shared_options = dict(
        batch_size=args.batch_size,
        num_workers=args.workers,
        worker_init_fn=worker_init,
    )

    # Training batches are shuffled and the last incomplete batch is dropped;
    # validation keeps every sample in dataset order.
    train_loader = DataLoader(train_set, shuffle=True, drop_last=True, **shared_options)
    valid_loader = DataLoader(valid_set, drop_last=False, **shared_options)

    return train_set, valid_set, train_loader, valid_loader

# 数据集定义
def datasets(args):
    """Instantiate the training and validation datasets.

    Args:
        args: configuration object providing ``images``, ``image_size``,
            ``aug_scale`` and ``aug_angle``.

    Returns:
        Tuple ``(train, valid)``. Only the training dataset receives the
        augmentation transform; the validation dataset is created with
        ``random_sampling=False``.
    """
    train_split = Dataset(
        images_dir=args.images,
        subset="train",
        image_size=args.image_size,
        transform=transforms(
            scale=args.aug_scale,
            angle=args.aug_angle,
            flip_prob=0.5,
        ),
    )
    valid_split = Dataset(
        images_dir=args.images,
        subset="validation",
        image_size=args.image_size,
        random_sampling=False,
    )
    return train_split, valid_split

# 数据处理
def dsc_per_volume(validation_pred, validation_true, patient_slice_index):
    """Compute one ``dsc`` score per patient volume.

    The flat per-slice lists are cut into consecutive chunks, one chunk per
    patient, where the chunk length is the number of entries in
    ``patient_slice_index`` whose first element equals that patient's index.

    Args:
        validation_pred: sequence of per-slice predictions.
        validation_true: sequence of per-slice ground-truth masks, in the
            same order as ``validation_pred``.
        patient_slice_index: sequence of pairs whose first element is the
            patient index of the corresponding slice.

    Returns:
        List with one ``dsc`` value per patient index.
    """
    slices_per_patient = np.bincount([entry[0] for entry in patient_slice_index])

    scores = []
    start = 0
    for count in slices_per_patient:
        stop = start + count
        volume_pred = np.array(validation_pred[start:stop])
        volume_true = np.array(validation_true[start:stop])
        scores.append(dsc(volume_pred, volume_true))
        start = stop
    return scores


def log_loss_summary(logger, loss, step, prefix=""):
    """Log the mean of ``loss`` as the scalar ``<prefix>loss`` at ``step``.

    Args:
        logger: object exposing ``scalar_summary(tag, value, step)``.
        loss: collection of loss values to average.
        step: step index to attach to the logged value.
        prefix: string prepended to the tag ``"loss"``.
    """
    tag = prefix + "loss"
    mean_loss = np.mean(loss)
    logger.scalar_summary(tag, mean_loss, step)


def makedirs(args):
    """Create the ``args.weights`` and ``args.logs`` directories if missing."""
    for attribute in ("weights", "logs"):
        os.makedirs(getattr(args, attribute), exist_ok=True)


def snapshotargs(args):
    """Write the attributes of ``args`` as JSON to ``<args.logs>/args.json``."""
    snapshot_path = os.path.join(args.logs, "args.json")
    with open(snapshot_path, "w") as handle:
        json.dump(vars(args), handle)

In [9]:
# Prepare output folders and store a snapshot of the configuration.
makedirs(args)
snapshotargs(args)

# Use the configured device only when CUDA is available; otherwise run on CPU.
device = torch.device(args.device if torch.cuda.is_available() else "cpu")

# Build datasets and loaders once; `loaders` maps phase name -> DataLoader.
dataset_train, dataset_valid, loader_train, loader_valid = data_loaders(args)
loaders = dict(train=loader_train, valid=loader_valid)


reading train images...
preprocessing train volumes...
cropping train volumes...
padding train volumes...
resizing train volumes...
normalizing train volumes...
done creating train dataset
reading validation images...
preprocessing validation volumes...
cropping validation volumes...
padding validation volumes...
resizing validation volumes...
normalizing validation volumes...
done creating validation dataset


In [37]:
from torch.utils.tensorboard import SummaryWriter
from tensorboard.plugins.hparams import api as hp
import torch.nn as nn

# Use the configured device only when CUDA is available; otherwise run on CPU.
device = torch.device(args.device if torch.cuda.is_available() else "cpu")

# Hyper-parameter search space.
HP_LR = hp.HParam('lr', domain=hp.RealInterval(min_value=1e-5, max_value=1e-3))
# Batch size is currently not part of the sweep:
# HP_BATCH_SIZE = hp.HParam('batch_size', hp.Discrete([16, 32, 64]))
HP_OPTIMIZER = hp.HParam('optimizer', domain=hp.Discrete(['adam', 'sgd']))
HP_LOSS_FN = hp.HParam('loss_fn', domain=hp.Discrete(['bce_logits', 'dice', 'focal']))

# Name under which each run's metric is reported.
METRIC_ACCURACY = 'accuracy'



In [34]:

# 实现模型训练和评估的函数

def _focal_loss_with_logits(logits, targets, alpha=0.25, gamma=2.0):
    """Binary focal loss on raw logits, averaged over all elements.

    Per-element loss is ``alpha_t * (1 - p_t) ** gamma * BCE``, where ``p_t``
    is the predicted probability of the true class.
    """
    bce = nn.functional.binary_cross_entropy_with_logits(
        logits, targets, reduction="none"
    )
    prob = torch.sigmoid(logits)
    p_t = prob * targets + (1 - prob) * (1 - targets)
    alpha_t = alpha * targets + (1 - alpha) * (1 - targets)
    return (alpha_t * (1 - p_t) ** gamma * bce).mean()


def _make_criterion(name):
    """Map a ``loss_fn`` hyper-parameter value to a callable ``(outputs, masks) -> loss``."""
    if name == 'bce_logits':
        return nn.BCEWithLogitsLoss()
    if name == 'dice':
        return DiceLoss()
    if name == 'focal':
        return _focal_loss_with_logits
    raise ValueError("unknown loss_fn: {!r}".format(name))


def train_eval_model(hparams, device, dataset_train, dataset_valid, loader_train, loader_valid):
    """Train a fresh UNet for one pass over ``loader_train`` and score it on ``loader_valid``.

    Args:
        hparams: dict with keys ``'optimizer'`` (``'adam'``; any other value
            selects SGD with momentum 0.9), ``'lr'`` and ``'loss_fn'``
            (``'bce_logits'``, ``'dice'`` or ``'focal'``).
        device: torch device the model and batches are moved to.
        dataset_train: unused; kept so existing callers keep working.
        dataset_valid: unused; kept so existing callers keep working.
        loader_train: iterable of ``(images, masks)`` training batches.
        loader_valid: iterable of ``(images, masks)`` validation batches.

    Returns:
        Pixel accuracy on the validation loader as a Python float.

    Raises:
        ValueError: if ``hparams['loss_fn']`` is not a supported name.
    """
    model = UNet(in_channels=Dataset.in_channels, out_channels=Dataset.out_channels)
    model.to(device)

    if hparams['optimizer'] == 'adam':
        optimizer = torch.optim.Adam(model.parameters(), lr=hparams['lr'])
    else:  # 'sgd'
        optimizer = torch.optim.SGD(model.parameters(), lr=hparams['lr'], momentum=0.9)

    # Build the loss once, outside the batch loop.
    # NOTE(review): 'bce_logits', 'focal' and the sigmoid in the evaluation
    # loop below treat the model output as raw logits, while 'dice' receives
    # the same output unchanged. Confirm what UNet.forward actually returns.
    criterion = _make_criterion(hparams['loss_fn'])

    # Training loop: a single pass over the training data.
    model.train()
    for images, masks in loader_train:
        images, masks = images.to(device), masks.to(device)
        optimizer.zero_grad()
        outputs = model(images)
        # Apply the criterion to the batch; backward() needs the resulting
        # tensor, not the loss module itself.
        loss = criterion(outputs, masks)
        loss.backward()
        optimizer.step()

    # Evaluation loop: switch to eval mode, as the main training cell does
    # for its validation phase.
    model.eval()
    num_correct = 0
    num_pixels = 0
    with torch.no_grad():
        for images, masks in loader_valid:
            images, masks = images.to(device), masks.to(device)
            outputs = torch.sigmoid(model(images))
            outputs = (outputs > 0.5).float()
            num_correct += (outputs == masks).sum().item()
            num_pixels += torch.numel(outputs)

    # Fraction of pixels whose thresholded prediction equals the mask.
    accuracy = num_correct / num_pixels
    return accuracy



In [None]:
# Prepare state for the hyper-parameter sweep: reset the run counter and
# build the train/validation datasets handed to train_eval_model.
# Image size and augmentation strength are read from `args` so that these
# datasets cannot drift from the ones built by datasets(args).
session_num = 0  # running index used to name sweep runs ("run-<n>")

images_dir = args.images
# Augmentation is applied to the training set only.
train_transform = transforms(scale=args.aug_scale, angle=args.aug_angle, flip_prob=0.5)
valid_transform = None  # no augmentation for validation

# Training set
dataset_train = Dataset(
    images_dir=images_dir,
    subset="train",
    image_size=args.image_size,
    transform=train_transform,
    random_sampling=True,
    validation_cases=10,  # number of cases held out for validation
    seed=42
)

# Validation set
dataset_valid = Dataset(
    images_dir=images_dir,
    subset="validation",
    image_size=args.image_size,
    transform=valid_transform,
    random_sampling=False,  # no random sampling during validation
    validation_cases=10,  # must match the training split above
    seed=42
)

In [39]:
# Grid search: for every combination of learning rate (the two interval
# end points), loss function and optimizer, train/evaluate one model and
# log the hyper-parameters with the resulting accuracy to TensorBoard.
for lr in np.linspace(HP_LR.domain.min_value, HP_LR.domain.max_value, num=2):
    for loss_fn in HP_LOSS_FN.domain.values:
        # The loop variable is deliberately not called `optimizer`: that
        # global name holds the optimizer object used by the main training
        # loop and must not be overwritten with a string.
        for optimizer_name in HP_OPTIMIZER.domain.values:
            hparams = {
                HP_LR.name: float(lr),  # plain Python float instead of np.float64
                HP_OPTIMIZER.name: optimizer_name,
                HP_LOSS_FN.name: loss_fn,
            }
            run_name = "run-%d" % session_num
            print('--- Starting trial: %s' % run_name)
            print(hparams)
            with SummaryWriter('logs/hparam_tuning/' + run_name) as writer:
                accuracy = train_eval_model(hparams, device, dataset_train, dataset_valid, loader_train, loader_valid)
                writer.add_hparams(hparam_dict=hparams, metric_dict={METRIC_ACCURACY: accuracy})
            # session_num persists across cell re-runs, so run names keep
            # increasing unless the counter is reset.
            session_num += 1

--- Starting trial: run-11
{'lr': 1e-05, 'optimizer': 'adam', 'loss_fn': 'bce_logits'}
--- Starting trial: run-12
{'lr': 1e-05, 'optimizer': 'sgd', 'loss_fn': 'bce_logits'}
--- Starting trial: run-13
{'lr': 1e-05, 'optimizer': 'adam', 'loss_fn': 'dice'}
--- Starting trial: run-14
{'lr': 1e-05, 'optimizer': 'sgd', 'loss_fn': 'dice'}
--- Starting trial: run-15
{'lr': 1e-05, 'optimizer': 'adam', 'loss_fn': 'focal'}
--- Starting trial: run-16
{'lr': 1e-05, 'optimizer': 'sgd', 'loss_fn': 'focal'}
--- Starting trial: run-17
{'lr': 0.001, 'optimizer': 'adam', 'loss_fn': 'bce_logits'}
--- Starting trial: run-18
{'lr': 0.001, 'optimizer': 'sgd', 'loss_fn': 'bce_logits'}
--- Starting trial: run-19
{'lr': 0.001, 'optimizer': 'adam', 'loss_fn': 'dice'}
--- Starting trial: run-20
{'lr': 0.001, 'optimizer': 'sgd', 'loss_fn': 'dice'}
--- Starting trial: run-21
{'lr': 0.001, 'optimizer': 'adam', 'loss_fn': 'focal'}
--- Starting trial: run-22
{'lr': 0.001, 'optimizer': 'sgd', 'loss_fn': 'focal'}


In [22]:
# Sanity check: display the dataset and loader objects currently bound in the kernel.
dataset_train, dataset_valid, loader_train, loader_valid

(<dataset.BrainSegmentationDataset at 0x7f41025d2d60>,
 <dataset.BrainSegmentationDataset at 0x7f4102635100>,
 <torch.utils.data.dataloader.DataLoader at 0x7f416686c640>,
 <torch.utils.data.dataloader.DataLoader at 0x7f410006f400>)

In [31]:
# Network to train, placed on the selected device.
unet = UNet(in_channels=Dataset.in_channels, out_channels=Dataset.out_channels)
unet.to(device)

# Dice loss is the training objective; the best validation DSC seen so far
# decides when a checkpoint is written.
dsc_loss = DiceLoss()
best_validation_dsc = 0.0

# Plain SGD at the configured learning rate.
optimizer = optim.SGD(unet.parameters(), lr=args.lr)

# Logging state: logger, loss buffers for both phases, and the global
# training-step counter.
logger = Logger(args.logs)
loss_train, loss_valid = [], []
step = 0


In [32]:
# Main training run: each epoch does one pass over the training loader
# followed by one pass over the validation loader.
for epoch in range(args.epochs):
    for phase in ["train", "valid"]:
        # Put the network in the mode matching the current phase.
        if phase == "train":
            unet.train()
        else:
            unet.eval()

        # Per-slice predictions and ground truth gathered during validation.
        validation_pred = []
        validation_true = []

        for i, data in enumerate(tqdm.tqdm(loaders[phase])):
            # `step` counts training batches only; validation logs reuse the
            # last training step.
            if phase == "train":
                step += 1

            x, y_true = data
            x, y_true = x.to(device), y_true.to(device)

            optimizer.zero_grad()

            # Gradients are tracked only while training.
            with torch.set_grad_enabled(phase == "train"):
                y_pred = unet(x)

                loss = dsc_loss(y_pred, y_true)

                if phase == "valid":
                    loss_valid.append(loss.item())
                    # Store every slice of the batch separately so they can
                    # be regrouped per patient afterwards.
                    y_pred_np = y_pred.detach().cpu().numpy()
                    validation_pred.extend(
                        [y_pred_np[s] for s in range(y_pred_np.shape[0])]
                    )
                    y_true_np = y_true.detach().cpu().numpy()
                    validation_true.extend(
                        [y_true_np[s] for s in range(y_true_np.shape[0])]
                    )
                    # Log example images every `vis_freq` epochs and on the
                    # last epoch, up to `vis_images` slices in total.
                    if (epoch % args.vis_freq == 0) or (epoch == args.epochs - 1):
                        if i * args.batch_size < args.vis_images:
                            tag = "image/{}".format(i)
                            num_images = args.vis_images - i * args.batch_size
                            logger.image_list_summary(
                                tag,
                                log_images(x, y_true, y_pred)[:num_images],
                                step,
                            )

                if phase == "train":
                    loss_train.append(loss.item())
                    loss.backward()
                    optimizer.step()

            # Flush the averaged training loss to the logger periodically.
            if phase == "train" and (step + 1) % 10 == 0:
                log_loss_summary(logger, loss_train, step)
                loss_train = []

        if phase == "valid":
            log_loss_summary(logger, loss_valid, step, prefix="val_")
            print("epoch {} | val_loss: {}".format(epoch + 1, np.mean(loss_valid)))
            mean_dsc = np.mean(
                dsc_per_volume(
                    validation_pred,
                    validation_true,
                    loader_valid.dataset.patient_slice_index,
                )
            )
            logger.scalar_summary("val_dsc", mean_dsc, step)
            print("epoch {} | val_dsc: {}".format(epoch+1, mean_dsc))
            # Keep only the weights with the best mean validation DSC so far.
            if mean_dsc > best_validation_dsc:
                best_validation_dsc = mean_dsc
                torch.save(unet.state_dict(), os.path.join(args.weights, "unet.pt"))
            loss_valid = []

# `.4f` requests four decimals; the previous `4f` only set a minimum width.
print("Best validation mean DSC: {:.4f}".format(best_validation_dsc))

100%|██████████| 208/208 [01:59<00:00,  1.73it/s]
100%|██████████| 21/21 [01:39<00:00,  4.72s/it]


epoch 1 | val_loss: 0.9674926740782601
epoch 1 | val_dsc: 0.05731747174627673
Best validation mean DSC: 0.057317


In [33]:
%reload_ext tensorboard

In [36]:
%tensorboard --logdir logs

In [35]:
# NOTE(review): hard-coded PID from an earlier session (presumably the TensorBoard
# server started above - confirm). After a restart this PID may belong to an
# unrelated process; look up the current PID before running this cell.
!kill 39618