# In [2]:
import os
import torch
import shutil
import argparse
import math
from torch.utils.data import DataLoader
import torch.optim.lr_scheduler as lr_scheduler
from dataset import MyDataset
import transforms as T
from model import UNet
from utils import (
    compute_gray,
    train_one_epoch,
    evaluate,
    plot,
    plot_lr_decay,
    plt_loss_iou
)


# Training-set preprocessing.
class SegmentationPresetTrain:
    """Random augmentation pipeline applied to an image together with its target.

    Steps, in order:
      1. random resize with bounds ``int(0.5 * base_size)`` .. ``int(1.5 * base_size)``;
      2. horizontal flip with probability ``hflip_prob`` (omitted if it is 0);
      3. vertical flip with probability ``vflip_prob`` (omitted if it is 0);
      4. random crop of size ``rcrop_size``;
      5. tensor conversion followed by normalization with ``mean`` / ``std``.
    """

    def __init__(self, base_size=600, rcrop_size=480, hflip_prob=0.5, vflip_prob=0.5, mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)):
        steps = [T.RandomResize(int(0.5 * base_size), int(1.5 * base_size))]

        # A non-positive probability drops the flip from the pipeline entirely.
        optional_flips = (
            (hflip_prob, T.RandomHorizontalFlip),
            (vflip_prob, T.RandomVerticalFlip),
        )
        steps += [flip_cls(prob) for prob, flip_cls in optional_flips if prob > 0]

        steps += [
            T.RandomCrop(rcrop_size),
            T.ToTensor(),
            T.Normalize(mean=mean, std=std),
        ]
        self.transforms = T.Compose(steps)

    def __call__(self, img, target):
        """Run the composed pipeline on ``img`` and ``target`` and return its result."""
        return self.transforms(img, target)


# Test-set preprocessing.
class SegmentationPresetTest:
    """Deterministic evaluation pipeline: tensor conversion, then normalization.

    No resize, flip or crop step is included.
    """

    def __init__(self, mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)):
        to_tensor = T.ToTensor()
        normalize = T.Normalize(mean=mean, std=std)
        self.transforms = T.Compose([to_tensor, normalize])

    def __call__(self, img, target):
        """Run the composed pipeline on ``img`` and ``target`` and return its result."""
        return self.transforms(img, target)


def main(args):
    """Train a UNet on ./data and record per-epoch results in ./run_results.

    Args:
        args: parsed command-line namespace. Reads ``base_size``,
            ``crop_size``, ``batch_size``, ``epochs``, ``lr`` and ``lrf``.

    Side effects:
        - Appends the hyper-parameters and each epoch's test confusion-matrix
          text to ``./run_results/train_log_results.txt``.
        - Saves the weights with the highest test mean IoU seen so far to
          ``./run_results/best_model.pth``.
        - Calls the ``plot_lr_decay`` and ``plt_loss_iou`` helpers from ``utils``.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else "cpu")
    print("Using {} device training.".format(device))

    log_path = './run_results/train_log_results.txt'
    with open(log_path, "a", encoding="utf-8") as f:
        f.write(f"[train hyper-parameters: {args}]\n")

    train_tf = SegmentationPresetTrain(base_size=args.base_size, rcrop_size=args.crop_size)
    test_tf = SegmentationPresetTest()

    # NOTE(review): compute_gray() is used as the class count (presumably the
    # number of distinct gray levels in the masks); its definition is not visible here.
    num_classes = compute_gray()

    trainDataset = MyDataset(imgs_path='./data/train/images', txt_path='./data/grayList.txt', transform=train_tf)
    testDataset = MyDataset(imgs_path='./data/test/images', txt_path='./data/grayList.txt', transform=test_tf)

    # os.cpu_count() returns None when the count cannot be determined; fall
    # back to 1 so min() does not raise TypeError. A batch size of 1 uses
    # 0 workers (loading in the main process). The worker count is capped at 8.
    cpu_count = os.cpu_count() or 1
    num_workers = min([cpu_count, args.batch_size if args.batch_size > 1 else 0, 8])
    print('Using %g dataloader workers' % num_workers)

    trainLoader = DataLoader(trainDataset, batch_size=args.batch_size, num_workers=num_workers, shuffle=True)
    testLoader = DataLoader(testDataset, batch_size=1, num_workers=num_workers, shuffle=False)

    model = UNet(num_classes=num_classes)
    model.to(device)

    optimizer = torch.optim.SGD(model.parameters(), lr=args.lr, momentum=0.9, weight_decay=1e-8)

    def lf(x):
        """Cosine LR multiplier: 1.0 at epoch 0, decaying to ``args.lrf`` at ``args.epochs``."""
        return ((1 + math.cos(x * math.pi / args.epochs)) / 2) * (1 - args.lrf) + args.lrf

    # A throwaway scheduler is handed to plot_lr_decay (which presumably steps
    # it to draw the curve). The training scheduler is then built fresh: the
    # first LambdaLR stored each group's 'initial_lr', and a new LambdaLR
    # resets lr to initial_lr * lf(0), so the plotting does not shift training.
    plot_scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lf)
    plot_lr_decay(plot_scheduler, optimizer, args.epochs)
    scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lf)

    # Start below any achievable mIoU so a checkpoint is written after the
    # first epoch even when its test mIoU is 0.
    best_mean_iou = float('-inf')
    train_loss_list = []
    test_loss_list = []
    train_miou_list = []
    test_miou_list = []
    for epoch in range(args.epochs):
        train_loss, test_loss, lr = train_one_epoch(model=model, optim=optimizer, train_loader=trainLoader, test_loader=testLoader, device=device)

        # The LR schedule advances once per epoch.
        scheduler.step()

        train_miou, test_miou, test_confmat = evaluate(model=model, train_loader=trainLoader, test_loader=testLoader, device=device, num=num_classes)

        train_loss_list.append(train_loss)
        test_loss_list.append(test_loss)
        train_miou_list.append(train_miou)
        test_miou_list.append(test_miou)

        with open(log_path, "a", encoding="utf-8") as f:
            # Formatting (rather than '+' concatenation) also accepts a
            # non-str confusion-matrix object; a str is written unchanged.
            f.write(f"[epoch: {epoch+1}]\n{test_confmat}\n\n")

        if test_miou > best_mean_iou:
            best_mean_iou = test_miou
            torch.save(model.state_dict(), './run_results/best_model.pth')

        print("[epoch:%d]" % (epoch + 1))
        print("learning rate:%.8f" % lr)
        print("train loss:%.4f \t train mean iou:%.4f" % (train_loss, train_miou))
        print("test loss:%.4f \t test mean iou:%.4f" % (test_loss, test_miou), end='\n\n')

    plt_loss_iou(train_loss_list, test_loss_list, train_miou_list, test_miou_list)


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description="unet segmentation")
    # (flag, default, type) for each command-line option, in registration order.
    cli_options = (
        ("--base-size", 400, int),
        ("--crop-size", 240, int),
        ("--batch-size", 8, int),
        ("--epochs", 10, int),
        ("--lr", 0.01, float),
        ("--lrf", 0.001, float),
    )
    for flag, default_value, value_type in cli_options:
        parser.add_argument(flag, default=default_value, type=value_type)

    # Unrecognised arguments are tolerated rather than rejected (presumably so
    # the script can also run inside a notebook kernel that passes its own flags).
    args, unknown = parser.parse_known_args()
    print(args)

    # Every run starts from an empty results directory; previous results are deleted.
    results_dir = "./run_results"
    if os.path.exists(results_dir):
        shutil.rmtree(results_dir)
    os.mkdir(results_dir)

    main(args)
