## 训练U-Net模型

新建python 3.10环境（以conda为例）

```
conda create -n hw4 python=3.10 -y
conda activate hw4
```

安装torch，注意cuda版本适配
```
pip install torch==2.0.* torchvision==0.15.* --index-url https://download.pytorch.org/whl/cu117
```

安装其他依赖库
```
pip install ipykernel==6.26.* matplotlib==3.8.* medpy==0.4.* scipy==1.11.* numpy==1.23.* scikit-image==0.22.* imageio==2.31.* tensorboard==2.15.* tqdm==4.* -i https://pypi.tuna.tsinghua.edu.cn/simple
```

In [5]:
import json
import os

from types import SimpleNamespace
import tqdm
import numpy as np
import torch
import torch.optim as optim
from torch.utils.data import DataLoader

from dataset import BrainSegmentationDataset as Dataset
from logger import Logger
from loss import DiceLoss
from transform import transforms
from unet import UNet
from utils import log_images, dsc

2024-03-07 10:34:04.288417: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


### 输入参数  
  
device: 设备编号  
batch_size: 批大小  
epochs: 训练轮数  
lr: 学习率  
vis_images: 可视化预测结果的数目 (在tensorboard中查看)  
vis_freq: 两次可视化预测结果的间隔  
weights: 训练后的模型参数路径    
images: 数据集路径   
image_size: 图像尺寸   
aug_scale: 数据增强(放缩)  
aug_angle: 数据增强(旋转)  

In [6]:
args = SimpleNamespace(
    device = 'cuda:0',
    batch_size = 16,
    epochs = 100,
    lr = 0.0001,
    workers = 0,
    vis_images = 200,
    vis_freq = 10,
    weights = './weights',
    logs = './logs',
    images = './kaggle_3m',
    image_size = 256,
    aug_scale = 0.05,
    aug_angle = 15,
)

In [23]:
args

namespace(device='cuda:0',
          batch_size=16,
          epochs=100,
          lr=0.0001,
          workers=0,
          vis_images=200,
          vis_freq=10,
          weights='./weights',
          logs='./logs',
          images='./kaggle_3m',
          image_size=256,
          aug_scale=0.05,
          aug_angle=15)

In [7]:
# 查看 numpy version
import numpy as np
np.__version__
%pip install medpy==0.4.* -i https://pypi.tuna.tsinghua.edu.cn/simple

Looking in indexes: https://pypi.tuna.tsinghua.edu.cn/simple

[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m23.3.2[0m[39;49m -> [0m[32;49m24.0[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpython3.11 -m pip install --upgrade pip[0m
Note: you may need to restart the kernel to use updated packages.


In [8]:
# 读取数据
def worker_init(worker_id):
    np.random.seed(42 + worker_id)

def data_loaders(args):
    dataset_train, dataset_valid = datasets(args)

    loader_train = DataLoader(
        dataset_train,
        batch_size=args.batch_size,
        shuffle=True,
        drop_last=True,
        num_workers=args.workers,
        worker_init_fn=worker_init,
    )
    loader_valid = DataLoader(
        dataset_valid,
        batch_size=args.batch_size,
        drop_last=False,
        num_workers=args.workers,
        worker_init_fn=worker_init,
    )

    return dataset_train, dataset_valid, loader_train, loader_valid

# 数据集定义
def datasets(args):
    train = Dataset(
        images_dir=args.images,
        subset="train",
        image_size=args.image_size,
        transform=transforms(scale=args.aug_scale, angle=args.aug_angle, flip_prob=0.5),
    )
    valid = Dataset(
        images_dir=args.images,
        subset="validation",
        image_size=args.image_size,
        random_sampling=False,
    )
    return train, valid

# 数据处理
def dsc_per_volume(validation_pred, validation_true, patient_slice_index):
    dsc_list = []
    num_slices = np.bincount([p[0] for p in patient_slice_index])
    index = 0
    for p in range(len(num_slices)):
        y_pred = np.array(validation_pred[index : index + num_slices[p]])
        y_true = np.array(validation_true[index : index + num_slices[p]])
        dsc_list.append(dsc(y_pred, y_true))
        index += num_slices[p]
    return dsc_list


def log_loss_summary(logger, loss, step, prefix=""):
    logger.scalar_summary(prefix + "loss", np.mean(loss), step)


def makedirs(args):
    os.makedirs(args.weights, exist_ok=True)
    os.makedirs(args.logs, exist_ok=True)


def snapshotargs(args):
    args_file = os.path.join(args.logs, "args.json")
    with open(args_file, "w") as fp:
        json.dump(vars(args), fp)

In [9]:
makedirs(args)
snapshotargs(args)
device = torch.device("cpu" if not torch.cuda.is_available() else args.device)

dataset_train, dataset_valid, loader_train, loader_valid = data_loaders(args)
loaders = {"train": loader_train, "valid": loader_valid}


reading train images...
preprocessing train volumes...
cropping train volumes...
padding train volumes...
resizing train volumes...
normalizing train volumes...
done creating train dataset
reading validation images...
preprocessing validation volumes...
cropping validation volumes...
padding validation volumes...
resizing validation volumes...
normalizing validation volumes...
done creating validation dataset


In [10]:
unet = UNet(in_channels=Dataset.in_channels, out_channels=Dataset.out_channels)
unet.to(device)

dsc_loss = DiceLoss()
best_validation_dsc = 0.0

optimizer = optim.Adam(unet.parameters(), lr=args.lr)

logger = Logger(args.logs)
loss_train = []
loss_valid = []

step = 0


In [12]:
for epoch in range(args.epochs):
    for phase in ["train", "valid"]:
        if phase == "train":
            unet.train()
        else:
            unet.eval()

        validation_pred = []
        validation_true = []

        for i, data in enumerate(tqdm.tqdm(loaders[phase])):
            if phase == "train":
                step += 1

            x, y_true = data
            x, y_true = x.to(device), y_true.to(device)

            optimizer.zero_grad()

            with torch.set_grad_enabled(phase == "train"):
                y_pred = unet(x)

                loss = dsc_loss(y_pred, y_true)

                if phase == "valid":
                    loss_valid.append(loss.item())
                    y_pred_np = y_pred.detach().cpu().numpy()
                    validation_pred.extend(
                        [y_pred_np[s] for s in range(y_pred_np.shape[0])]
                    )
                    y_true_np = y_true.detach().cpu().numpy()
                    validation_true.extend(
                        [y_true_np[s] for s in range(y_true_np.shape[0])]
                    )
                    if (epoch % args.vis_freq == 0) or (epoch == args.epochs - 1):
                        if i * args.batch_size < args.vis_images:
                            tag = "image/{}".format(i)
                            num_images = args.vis_images - i * args.batch_size
                            logger.image_list_summary(
                                tag,
                                log_images(x, y_true, y_pred)[:num_images],
                                step,
                            )

                if phase == "train":
                    loss_train.append(loss.item())
                    loss.backward()
                    optimizer.step()

            if phase == "train" and (step + 1) % 10 == 0:
                log_loss_summary(logger, loss_train, step)
                loss_train = []

        if phase == "valid":
            log_loss_summary(logger, loss_valid, step, prefix="val_")
            print("epoch {} | val_loss: {}".format(epoch + 1, np.mean(loss_valid)))
            mean_dsc = np.mean(
                dsc_per_volume(
                    validation_pred,
                    validation_true,
                    loader_valid.dataset.patient_slice_index,
                )
            )
            logger.scalar_summary("val_dsc", mean_dsc, step)
            print("epoch {} | val_dsc: {}".format(epoch+1, mean_dsc))
            if mean_dsc > best_validation_dsc:
                best_validation_dsc = mean_dsc
                torch.save(unet.state_dict(), os.path.join(args.weights, "unet.pt"))
            loss_valid = []

print("Best validation mean DSC: {:4f}".format(best_validation_dsc))

In [29]:
!tensorboard --logdir=./logs

2024-03-07 23:06:07.195576: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.

NOTE: Using experimental fast data loading logic. To disable, pass
    "--load_fast=false" and report issues on GitHub. More details:
    https://github.com/tensorflow/tensorboard/issues/4784

Serving TensorBoard on localhost; to expose to the network, use a proxy or pass --bind_all
TensorBoard 2.15.1 at http://localhost:6008/ (Press CTRL+C to quit)
^C


In [18]:
!tensorboard --logdir=./logs/hparam_tuning

In [28]:
%tensorboard --logdir logs

Launching TensorBoard...